mirror of
https://github.com/openziti/zrok.git
synced 2024-11-21 23:53:19 +01:00
better authorization handling
This commit is contained in:
parent
9e0caf192b
commit
069417ade0
@ -27,6 +27,25 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
envId := params.Body.Identity
|
||||
if is, err := str.FindIdentitiesForAccount(int(principal.ID), tx); err == nil {
|
||||
found := false
|
||||
for _, i := range is {
|
||||
if i.ZitiId == envId {
|
||||
logrus.Infof("found identity '%v' for user '%v'", envId, principal.Username)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
logrus.Errorf("identity '%v' not found for user '%v'", envId, principal.Username)
|
||||
return tunnel.NewTunnelUnauthorized().WithPayload("bad environment identity")
|
||||
}
|
||||
} else {
|
||||
logrus.Errorf("error finding identities for account '%v'", principal.Username)
|
||||
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
|
||||
edge, err := edgeClient()
|
||||
if err != nil {
|
||||
logrus.Error(err)
|
||||
@ -42,7 +61,6 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi
|
||||
logrus.Error(err)
|
||||
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
envId := params.Body.Identity
|
||||
if err := createServicePolicyBind(svcName, svcId, envId, edge); err != nil {
|
||||
logrus.Error(err)
|
||||
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/go-openapi/runtime/middleware"
|
||||
"github.com/openziti-test-kitchen/zrok/controller/store"
|
||||
"github.com/openziti-test-kitchen/zrok/rest_model_zrok"
|
||||
"github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/tunnel"
|
||||
"github.com/openziti/edge/rest_management_api_client"
|
||||
@ -11,6 +12,7 @@ import (
|
||||
"github.com/openziti/edge/rest_management_api_client/service"
|
||||
"github.com/openziti/edge/rest_management_api_client/service_edge_router_policy"
|
||||
"github.com/openziti/edge/rest_management_api_client/service_policy"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"time"
|
||||
)
|
||||
@ -18,12 +20,42 @@ import (
|
||||
func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Principal) middleware.Responder {
|
||||
logrus.Infof("untunneling for '%v' (%v)", principal.Username, principal.Token)
|
||||
|
||||
tx, err := str.Begin()
|
||||
if err != nil {
|
||||
logrus.Errorf("error starting transaction: %v", err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
edge, err := edgeClient()
|
||||
if err != nil {
|
||||
logrus.Error(err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
svcName := params.Body.Service
|
||||
svcId, err := findServiceId(svcName, edge)
|
||||
if err != nil {
|
||||
logrus.Error(err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
var ssvc *store.Service
|
||||
if svcs, err := str.FindServicesForAccount(int(principal.ID), tx); err == nil {
|
||||
for _, svc := range svcs {
|
||||
if svc.ZitiId == svcId {
|
||||
ssvc = svc
|
||||
break
|
||||
}
|
||||
}
|
||||
if ssvc == nil {
|
||||
err := errors.Errorf("service with id '%v' not found for '%v'", svcId, principal.Username)
|
||||
logrus.Error(err)
|
||||
return tunnel.NewUntunnelNotFound().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
} else {
|
||||
logrus.Errorf("error finding services for account '%v'", principal.Username)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
|
||||
if err := deleteEdgeRouterPolicy(svcName, edge); err != nil {
|
||||
logrus.Error(err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
@ -40,47 +72,47 @@ func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Pr
|
||||
logrus.Error(err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
svcId, err := deleteService(svcName, edge)
|
||||
if err != nil {
|
||||
if err := deleteService(svcId, edge); err != nil {
|
||||
logrus.Error(err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
|
||||
logrus.Infof("deallocated service '%v'", svcName)
|
||||
|
||||
tx, err := str.Begin()
|
||||
if err != nil {
|
||||
logrus.Errorf("error starting transaction: %v", err)
|
||||
ssvc.Active = false
|
||||
if err := str.UpdateService(ssvc, tx); err != nil {
|
||||
logrus.Errorf("error deactivating service '%v': %v", svcId, err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
svcs, err := str.FindServicesForAccount(int(principal.ID), tx)
|
||||
if err != nil {
|
||||
logrus.Errorf("error finding services for account: %v", err)
|
||||
if err := tx.Commit(); err != nil {
|
||||
logrus.Errorf("error committing: %v", err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
changed := false
|
||||
for _, svc := range svcs {
|
||||
if svc.ZitiId == svcId {
|
||||
svc.Active = false
|
||||
if err := str.UpdateService(svc, tx); err != nil {
|
||||
logrus.Errorf("error deactivating service '%v': %v", svcId, err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
changed = true
|
||||
logrus.Infof("deactivated service '%v'", svcId)
|
||||
}
|
||||
}
|
||||
if changed {
|
||||
if err := tx.Commit(); err != nil {
|
||||
logrus.Errorf("error committing: %v", err)
|
||||
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
return tunnel.NewUntunnelOK()
|
||||
}
|
||||
|
||||
func findServiceId(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) {
|
||||
filter := fmt.Sprintf("name=\"%v\"", svcName)
|
||||
limit := int64(1)
|
||||
offset := int64(0)
|
||||
listReq := &service.ListServicesParams{
|
||||
Filter: &filter,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
Context: context.Background(),
|
||||
}
|
||||
listReq.SetTimeout(30 * time.Second)
|
||||
listResp, err := edge.Service.ListServices(listReq, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(listResp.Payload.Data) == 1 {
|
||||
return *(listResp.Payload.Data[0].ID), nil
|
||||
}
|
||||
return "", errors.Errorf("service '%v' not found", svcName)
|
||||
}
|
||||
|
||||
func deleteEdgeRouterPolicy(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) error {
|
||||
filter := fmt.Sprintf("name=\"%v\"", svcName)
|
||||
limit := int64(1)
|
||||
@ -187,36 +219,16 @@ func deleteServicePolicy(filter string, edge *rest_management_api_client.ZitiEdg
|
||||
return nil
|
||||
}
|
||||
|
||||
func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) {
|
||||
filter := fmt.Sprintf("name=\"%v\"", svcName)
|
||||
limit := int64(1)
|
||||
offset := int64(0)
|
||||
listReq := &service.ListServicesParams{
|
||||
Filter: &filter,
|
||||
Limit: &limit,
|
||||
Offset: &offset,
|
||||
func deleteService(svcId string, edge *rest_management_api_client.ZitiEdgeManagement) error {
|
||||
req := &service.DeleteServiceParams{
|
||||
ID: svcId,
|
||||
Context: context.Background(),
|
||||
}
|
||||
listReq.SetTimeout(30 * time.Second)
|
||||
listResp, err := edge.Service.ListServices(listReq, nil)
|
||||
req.SetTimeout(30 * time.Second)
|
||||
_, err := edge.Service.DeleteService(req, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
if len(listResp.Payload.Data) == 1 {
|
||||
svcId := *(listResp.Payload.Data[0].ID)
|
||||
req := &service.DeleteServiceParams{
|
||||
ID: svcId,
|
||||
Context: context.Background(),
|
||||
}
|
||||
req.SetTimeout(30 * time.Second)
|
||||
_, err := edge.Service.DeleteService(req, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
logrus.Infof("deleted service '%v'", svcId)
|
||||
return svcId, nil
|
||||
} else {
|
||||
logrus.Infof("did not find a service")
|
||||
}
|
||||
return "", nil
|
||||
logrus.Infof("deleted service '%v'", svcId)
|
||||
return nil
|
||||
}
|
||||
|
@ -29,6 +29,12 @@ func (o *TunnelReader) ReadResponse(response runtime.ClientResponse, consumer ru
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
case 401:
|
||||
result := NewTunnelUnauthorized()
|
||||
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, result
|
||||
case 500:
|
||||
result := NewTunnelInternalServerError()
|
||||
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
||||
@ -72,6 +78,36 @@ func (o *TunnelCreated) readResponse(response runtime.ClientResponse, consumer r
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTunnelUnauthorized creates a TunnelUnauthorized with default headers values
|
||||
func NewTunnelUnauthorized() *TunnelUnauthorized {
|
||||
return &TunnelUnauthorized{}
|
||||
}
|
||||
|
||||
/* TunnelUnauthorized describes a response with status code 401, with default header values.
|
||||
|
||||
invalid environment identity
|
||||
*/
|
||||
type TunnelUnauthorized struct {
|
||||
Payload rest_model_zrok.ErrorMessage
|
||||
}
|
||||
|
||||
func (o *TunnelUnauthorized) Error() string {
|
||||
return fmt.Sprintf("[POST /tunnel][%d] tunnelUnauthorized %+v", 401, o.Payload)
|
||||
}
|
||||
func (o *TunnelUnauthorized) GetPayload() rest_model_zrok.ErrorMessage {
|
||||
return o.Payload
|
||||
}
|
||||
|
||||
func (o *TunnelUnauthorized) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
|
||||
|
||||
// response payload
|
||||
if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewTunnelInternalServerError creates a TunnelInternalServerError with default headers values
|
||||
func NewTunnelInternalServerError() *TunnelInternalServerError {
|
||||
return &TunnelInternalServerError{}
|
||||
|
@ -29,6 +29,12 @@ func (o *UntunnelReader) ReadResponse(response runtime.ClientResponse, consumer
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
case 404:
|
||||
result := NewUntunnelNotFound()
|
||||
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, result
|
||||
case 500:
|
||||
result := NewUntunnelInternalServerError()
|
||||
if err := result.readResponse(response, consumer, o.formats); err != nil {
|
||||
@ -61,6 +67,36 @@ func (o *UntunnelOK) readResponse(response runtime.ClientResponse, consumer runt
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUntunnelNotFound creates a UntunnelNotFound with default headers values
|
||||
func NewUntunnelNotFound() *UntunnelNotFound {
|
||||
return &UntunnelNotFound{}
|
||||
}
|
||||
|
||||
/* UntunnelNotFound describes a response with status code 404, with default header values.
|
||||
|
||||
not found
|
||||
*/
|
||||
type UntunnelNotFound struct {
|
||||
Payload rest_model_zrok.ErrorMessage
|
||||
}
|
||||
|
||||
func (o *UntunnelNotFound) Error() string {
|
||||
return fmt.Sprintf("[DELETE /untunnel][%d] untunnelNotFound %+v", 404, o.Payload)
|
||||
}
|
||||
func (o *UntunnelNotFound) GetPayload() rest_model_zrok.ErrorMessage {
|
||||
return o.Payload
|
||||
}
|
||||
|
||||
func (o *UntunnelNotFound) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
|
||||
|
||||
// response payload
|
||||
if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUntunnelInternalServerError creates a UntunnelInternalServerError with default headers values
|
||||
func NewUntunnelInternalServerError() *UntunnelInternalServerError {
|
||||
return &UntunnelInternalServerError{}
|
||||
|
@ -131,6 +131,12 @@ func init() {
|
||||
"$ref": "#/definitions/tunnelResponse"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "invalid environment identity",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/errorMessage"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "internal server error",
|
||||
"schema": {
|
||||
@ -164,6 +170,12 @@ func init() {
|
||||
"200": {
|
||||
"description": "tunnel removed"
|
||||
},
|
||||
"404": {
|
||||
"description": "not found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/errorMessage"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "internal server error",
|
||||
"schema": {
|
||||
@ -391,6 +403,12 @@ func init() {
|
||||
"$ref": "#/definitions/tunnelResponse"
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"description": "invalid environment identity",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/errorMessage"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "internal server error",
|
||||
"schema": {
|
||||
@ -424,6 +442,12 @@ func init() {
|
||||
"200": {
|
||||
"description": "tunnel removed"
|
||||
},
|
||||
"404": {
|
||||
"description": "not found",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/errorMessage"
|
||||
}
|
||||
},
|
||||
"500": {
|
||||
"description": "internal server error",
|
||||
"schema": {
|
||||
|
@ -57,6 +57,48 @@ func (o *TunnelCreated) WriteResponse(rw http.ResponseWriter, producer runtime.P
|
||||
}
|
||||
}
|
||||
|
||||
// TunnelUnauthorizedCode is the HTTP code returned for type TunnelUnauthorized
|
||||
const TunnelUnauthorizedCode int = 401
|
||||
|
||||
/*TunnelUnauthorized invalid environment identity
|
||||
|
||||
swagger:response tunnelUnauthorized
|
||||
*/
|
||||
type TunnelUnauthorized struct {
|
||||
|
||||
/*
|
||||
In: Body
|
||||
*/
|
||||
Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// NewTunnelUnauthorized creates TunnelUnauthorized with default headers values
|
||||
func NewTunnelUnauthorized() *TunnelUnauthorized {
|
||||
|
||||
return &TunnelUnauthorized{}
|
||||
}
|
||||
|
||||
// WithPayload adds the payload to the tunnel unauthorized response
|
||||
func (o *TunnelUnauthorized) WithPayload(payload rest_model_zrok.ErrorMessage) *TunnelUnauthorized {
|
||||
o.Payload = payload
|
||||
return o
|
||||
}
|
||||
|
||||
// SetPayload sets the payload to the tunnel unauthorized response
|
||||
func (o *TunnelUnauthorized) SetPayload(payload rest_model_zrok.ErrorMessage) {
|
||||
o.Payload = payload
|
||||
}
|
||||
|
||||
// WriteResponse to the client
|
||||
func (o *TunnelUnauthorized) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
|
||||
|
||||
rw.WriteHeader(401)
|
||||
payload := o.Payload
|
||||
if err := producer.Produce(rw, payload); err != nil {
|
||||
panic(err) // let the recovery middleware deal with this
|
||||
}
|
||||
}
|
||||
|
||||
// TunnelInternalServerErrorCode is the HTTP code returned for type TunnelInternalServerError
|
||||
const TunnelInternalServerErrorCode int = 500
|
||||
|
||||
|
@ -37,6 +37,48 @@ func (o *UntunnelOK) WriteResponse(rw http.ResponseWriter, producer runtime.Prod
|
||||
rw.WriteHeader(200)
|
||||
}
|
||||
|
||||
// UntunnelNotFoundCode is the HTTP code returned for type UntunnelNotFound
|
||||
const UntunnelNotFoundCode int = 404
|
||||
|
||||
/*UntunnelNotFound not found
|
||||
|
||||
swagger:response untunnelNotFound
|
||||
*/
|
||||
type UntunnelNotFound struct {
|
||||
|
||||
/*
|
||||
In: Body
|
||||
*/
|
||||
Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
// NewUntunnelNotFound creates UntunnelNotFound with default headers values
|
||||
func NewUntunnelNotFound() *UntunnelNotFound {
|
||||
|
||||
return &UntunnelNotFound{}
|
||||
}
|
||||
|
||||
// WithPayload adds the payload to the untunnel not found response
|
||||
func (o *UntunnelNotFound) WithPayload(payload rest_model_zrok.ErrorMessage) *UntunnelNotFound {
|
||||
o.Payload = payload
|
||||
return o
|
||||
}
|
||||
|
||||
// SetPayload sets the payload to the untunnel not found response
|
||||
func (o *UntunnelNotFound) SetPayload(payload rest_model_zrok.ErrorMessage) {
|
||||
o.Payload = payload
|
||||
}
|
||||
|
||||
// WriteResponse to the client
|
||||
func (o *UntunnelNotFound) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
|
||||
|
||||
rw.WriteHeader(404)
|
||||
payload := o.Payload
|
||||
if err := producer.Produce(rw, payload); err != nil {
|
||||
panic(err) // let the recovery middleware deal with this
|
||||
}
|
||||
}
|
||||
|
||||
// UntunnelInternalServerErrorCode is the HTTP code returned for type UntunnelInternalServerError
|
||||
const UntunnelInternalServerErrorCode int = 500
|
||||
|
||||
|
@ -70,6 +70,10 @@ paths:
|
||||
description: tunnel created
|
||||
schema:
|
||||
$ref: "#/definitions/tunnelResponse"
|
||||
401:
|
||||
description: invalid environment identity
|
||||
schema:
|
||||
$ref: "#/definitions/errorMessage"
|
||||
500:
|
||||
description: internal server error
|
||||
schema:
|
||||
@ -89,6 +93,10 @@ paths:
|
||||
responses:
|
||||
200:
|
||||
description: tunnel removed
|
||||
404:
|
||||
description: not found
|
||||
schema:
|
||||
$ref: "#/definitions/errorMessage"
|
||||
500:
|
||||
description: internal server error
|
||||
schema:
|
||||
|
Loading…
Reference in New Issue
Block a user