better authorization handling

This commit is contained in:
Michael Quigley 2022-08-01 15:44:26 -04:00
parent 9e0caf192b
commit 069417ade0
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
8 changed files with 274 additions and 56 deletions

View File

@ -27,6 +27,25 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi
}
defer func() { _ = tx.Rollback() }()
envId := params.Body.Identity
if is, err := str.FindIdentitiesForAccount(int(principal.ID), tx); err == nil {
found := false
for _, i := range is {
if i.ZitiId == envId {
logrus.Infof("found identity '%v' for user '%v'", envId, principal.Username)
found = true
break
}
}
if !found {
logrus.Errorf("identity '%v' not found for user '%v'", envId, principal.Username)
return tunnel.NewTunnelUnauthorized().WithPayload("bad environment identity")
}
} else {
logrus.Errorf("error finding identities for account '%v'", principal.Username)
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
edge, err := edgeClient()
if err != nil {
logrus.Error(err)
@ -42,7 +61,6 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi
logrus.Error(err)
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
envId := params.Body.Identity
if err := createServicePolicyBind(svcName, svcId, envId, edge); err != nil {
logrus.Error(err)
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"github.com/go-openapi/runtime/middleware"
"github.com/openziti-test-kitchen/zrok/controller/store"
"github.com/openziti-test-kitchen/zrok/rest_model_zrok"
"github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/tunnel"
"github.com/openziti/edge/rest_management_api_client"
@ -11,6 +12,7 @@ import (
"github.com/openziti/edge/rest_management_api_client/service"
"github.com/openziti/edge/rest_management_api_client/service_edge_router_policy"
"github.com/openziti/edge/rest_management_api_client/service_policy"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"time"
)
@ -18,12 +20,42 @@ import (
func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Principal) middleware.Responder {
logrus.Infof("untunneling for '%v' (%v)", principal.Username, principal.Token)
tx, err := str.Begin()
if err != nil {
logrus.Errorf("error starting transaction: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
defer func() { _ = tx.Rollback() }()
edge, err := edgeClient()
if err != nil {
logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
svcName := params.Body.Service
svcId, err := findServiceId(svcName, edge)
if err != nil {
logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
var ssvc *store.Service
if svcs, err := str.FindServicesForAccount(int(principal.ID), tx); err == nil {
for _, svc := range svcs {
if svc.ZitiId == svcId {
ssvc = svc
break
}
}
if ssvc == nil {
err := errors.Errorf("service with id '%v' not found for '%v'", svcId, principal.Username)
logrus.Error(err)
return tunnel.NewUntunnelNotFound().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
} else {
logrus.Errorf("error finding services for account '%v'", principal.Username)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
if err := deleteEdgeRouterPolicy(svcName, edge); err != nil {
logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
@ -40,47 +72,47 @@ func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Pr
logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
svcId, err := deleteService(svcName, edge)
if err != nil {
if err := deleteService(svcId, edge); err != nil {
logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
logrus.Infof("deallocated service '%v'", svcName)
tx, err := str.Begin()
if err != nil {
logrus.Errorf("error starting transaction: %v", err)
ssvc.Active = false
if err := str.UpdateService(ssvc, tx); err != nil {
logrus.Errorf("error deactivating service '%v': %v", svcId, err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
defer func() { _ = tx.Rollback() }()
svcs, err := str.FindServicesForAccount(int(principal.ID), tx)
if err != nil {
logrus.Errorf("error finding services for account: %v", err)
if err := tx.Commit(); err != nil {
logrus.Errorf("error committing: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
changed := false
for _, svc := range svcs {
if svc.ZitiId == svcId {
svc.Active = false
if err := str.UpdateService(svc, tx); err != nil {
logrus.Errorf("error deactivating service '%v': %v", svcId, err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
changed = true
logrus.Infof("deactivated service '%v'", svcId)
}
}
if changed {
if err := tx.Commit(); err != nil {
logrus.Errorf("error committing: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
}
return tunnel.NewUntunnelOK()
}
func findServiceId(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) {
filter := fmt.Sprintf("name=\"%v\"", svcName)
limit := int64(1)
offset := int64(0)
listReq := &service.ListServicesParams{
Filter: &filter,
Limit: &limit,
Offset: &offset,
Context: context.Background(),
}
listReq.SetTimeout(30 * time.Second)
listResp, err := edge.Service.ListServices(listReq, nil)
if err != nil {
return "", err
}
if len(listResp.Payload.Data) == 1 {
return *(listResp.Payload.Data[0].ID), nil
}
return "", errors.Errorf("service '%v' not found", svcName)
}
func deleteEdgeRouterPolicy(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) error {
filter := fmt.Sprintf("name=\"%v\"", svcName)
limit := int64(1)
@ -187,36 +219,16 @@ func deleteServicePolicy(filter string, edge *rest_management_api_client.ZitiEdg
return nil
}
func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) {
filter := fmt.Sprintf("name=\"%v\"", svcName)
limit := int64(1)
offset := int64(0)
listReq := &service.ListServicesParams{
Filter: &filter,
Limit: &limit,
Offset: &offset,
func deleteService(svcId string, edge *rest_management_api_client.ZitiEdgeManagement) error {
req := &service.DeleteServiceParams{
ID: svcId,
Context: context.Background(),
}
listReq.SetTimeout(30 * time.Second)
listResp, err := edge.Service.ListServices(listReq, nil)
req.SetTimeout(30 * time.Second)
_, err := edge.Service.DeleteService(req, nil)
if err != nil {
return "", err
return err
}
if len(listResp.Payload.Data) == 1 {
svcId := *(listResp.Payload.Data[0].ID)
req := &service.DeleteServiceParams{
ID: svcId,
Context: context.Background(),
}
req.SetTimeout(30 * time.Second)
_, err := edge.Service.DeleteService(req, nil)
if err != nil {
return "", err
}
logrus.Infof("deleted service '%v'", svcId)
return svcId, nil
} else {
logrus.Infof("did not find a service")
}
return "", nil
logrus.Infof("deleted service '%v'", svcId)
return nil
}

View File

@ -29,6 +29,12 @@ func (o *TunnelReader) ReadResponse(response runtime.ClientResponse, consumer ru
return nil, err
}
return result, nil
case 401:
result := NewTunnelUnauthorized()
if err := result.readResponse(response, consumer, o.formats); err != nil {
return nil, err
}
return nil, result
case 500:
result := NewTunnelInternalServerError()
if err := result.readResponse(response, consumer, o.formats); err != nil {
@ -72,6 +78,36 @@ func (o *TunnelCreated) readResponse(response runtime.ClientResponse, consumer r
return nil
}
// NewTunnelUnauthorized creates a TunnelUnauthorized with default headers values
func NewTunnelUnauthorized() *TunnelUnauthorized {
return &TunnelUnauthorized{}
}
/* TunnelUnauthorized describes a response with status code 401, with default header values.
invalid environment identity
*/
type TunnelUnauthorized struct {
Payload rest_model_zrok.ErrorMessage
}
func (o *TunnelUnauthorized) Error() string {
return fmt.Sprintf("[POST /tunnel][%d] tunnelUnauthorized %+v", 401, o.Payload)
}
func (o *TunnelUnauthorized) GetPayload() rest_model_zrok.ErrorMessage {
return o.Payload
}
func (o *TunnelUnauthorized) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
// response payload
if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF {
return err
}
return nil
}
// NewTunnelInternalServerError creates a TunnelInternalServerError with default headers values
func NewTunnelInternalServerError() *TunnelInternalServerError {
return &TunnelInternalServerError{}

View File

@ -29,6 +29,12 @@ func (o *UntunnelReader) ReadResponse(response runtime.ClientResponse, consumer
return nil, err
}
return result, nil
case 404:
result := NewUntunnelNotFound()
if err := result.readResponse(response, consumer, o.formats); err != nil {
return nil, err
}
return nil, result
case 500:
result := NewUntunnelInternalServerError()
if err := result.readResponse(response, consumer, o.formats); err != nil {
@ -61,6 +67,36 @@ func (o *UntunnelOK) readResponse(response runtime.ClientResponse, consumer runt
return nil
}
// NewUntunnelNotFound creates a UntunnelNotFound with default headers values
func NewUntunnelNotFound() *UntunnelNotFound {
return &UntunnelNotFound{}
}
/* UntunnelNotFound describes a response with status code 404, with default header values.
not found
*/
type UntunnelNotFound struct {
Payload rest_model_zrok.ErrorMessage
}
func (o *UntunnelNotFound) Error() string {
return fmt.Sprintf("[DELETE /untunnel][%d] untunnelNotFound %+v", 404, o.Payload)
}
func (o *UntunnelNotFound) GetPayload() rest_model_zrok.ErrorMessage {
return o.Payload
}
func (o *UntunnelNotFound) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
// response payload
if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF {
return err
}
return nil
}
// NewUntunnelInternalServerError creates a UntunnelInternalServerError with default headers values
func NewUntunnelInternalServerError() *UntunnelInternalServerError {
return &UntunnelInternalServerError{}

View File

@ -131,6 +131,12 @@ func init() {
"$ref": "#/definitions/tunnelResponse"
}
},
"401": {
"description": "invalid environment identity",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": {
"description": "internal server error",
"schema": {
@ -164,6 +170,12 @@ func init() {
"200": {
"description": "tunnel removed"
},
"404": {
"description": "not found",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": {
"description": "internal server error",
"schema": {
@ -391,6 +403,12 @@ func init() {
"$ref": "#/definitions/tunnelResponse"
}
},
"401": {
"description": "invalid environment identity",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": {
"description": "internal server error",
"schema": {
@ -424,6 +442,12 @@ func init() {
"200": {
"description": "tunnel removed"
},
"404": {
"description": "not found",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": {
"description": "internal server error",
"schema": {

View File

@ -57,6 +57,48 @@ func (o *TunnelCreated) WriteResponse(rw http.ResponseWriter, producer runtime.P
}
}
// TunnelUnauthorizedCode is the HTTP code returned for type TunnelUnauthorized
const TunnelUnauthorizedCode int = 401
/*TunnelUnauthorized invalid environment identity
swagger:response tunnelUnauthorized
*/
type TunnelUnauthorized struct {
/*
In: Body
*/
Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"`
}
// NewTunnelUnauthorized creates TunnelUnauthorized with default headers values
func NewTunnelUnauthorized() *TunnelUnauthorized {
return &TunnelUnauthorized{}
}
// WithPayload adds the payload to the tunnel unauthorized response
func (o *TunnelUnauthorized) WithPayload(payload rest_model_zrok.ErrorMessage) *TunnelUnauthorized {
o.Payload = payload
return o
}
// SetPayload sets the payload to the tunnel unauthorized response
func (o *TunnelUnauthorized) SetPayload(payload rest_model_zrok.ErrorMessage) {
o.Payload = payload
}
// WriteResponse to the client
func (o *TunnelUnauthorized) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
rw.WriteHeader(401)
payload := o.Payload
if err := producer.Produce(rw, payload); err != nil {
panic(err) // let the recovery middleware deal with this
}
}
// TunnelInternalServerErrorCode is the HTTP code returned for type TunnelInternalServerError
const TunnelInternalServerErrorCode int = 500

View File

@ -37,6 +37,48 @@ func (o *UntunnelOK) WriteResponse(rw http.ResponseWriter, producer runtime.Prod
rw.WriteHeader(200)
}
// UntunnelNotFoundCode is the HTTP code returned for type UntunnelNotFound
const UntunnelNotFoundCode int = 404
/*UntunnelNotFound not found
swagger:response untunnelNotFound
*/
type UntunnelNotFound struct {
/*
In: Body
*/
Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"`
}
// NewUntunnelNotFound creates UntunnelNotFound with default headers values
func NewUntunnelNotFound() *UntunnelNotFound {
return &UntunnelNotFound{}
}
// WithPayload adds the payload to the untunnel not found response
func (o *UntunnelNotFound) WithPayload(payload rest_model_zrok.ErrorMessage) *UntunnelNotFound {
o.Payload = payload
return o
}
// SetPayload sets the payload to the untunnel not found response
func (o *UntunnelNotFound) SetPayload(payload rest_model_zrok.ErrorMessage) {
o.Payload = payload
}
// WriteResponse to the client
func (o *UntunnelNotFound) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
rw.WriteHeader(404)
payload := o.Payload
if err := producer.Produce(rw, payload); err != nil {
panic(err) // let the recovery middleware deal with this
}
}
// UntunnelInternalServerErrorCode is the HTTP code returned for type UntunnelInternalServerError
const UntunnelInternalServerErrorCode int = 500

View File

@ -70,6 +70,10 @@ paths:
description: tunnel created
schema:
$ref: "#/definitions/tunnelResponse"
401:
description: invalid environment identity
schema:
$ref: "#/definitions/errorMessage"
500:
description: internal server error
schema:
@ -89,6 +93,10 @@ paths:
responses:
200:
description: tunnel removed
404:
description: not found
schema:
$ref: "#/definitions/errorMessage"
500:
description: internal server error
schema: