better authorization handling

This commit is contained in:
Michael Quigley 2022-08-01 15:44:26 -04:00
parent 9e0caf192b
commit 069417ade0
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
8 changed files with 274 additions and 56 deletions

View File

@ -27,6 +27,25 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi
} }
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
envId := params.Body.Identity
if is, err := str.FindIdentitiesForAccount(int(principal.ID), tx); err == nil {
found := false
for _, i := range is {
if i.ZitiId == envId {
logrus.Infof("found identity '%v' for user '%v'", envId, principal.Username)
found = true
break
}
}
if !found {
logrus.Errorf("identity '%v' not found for user '%v'", envId, principal.Username)
return tunnel.NewTunnelUnauthorized().WithPayload("bad environment identity")
}
} else {
logrus.Errorf("error finding identities for account '%v'", principal.Username)
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
edge, err := edgeClient() edge, err := edgeClient()
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
@ -42,7 +61,6 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi
logrus.Error(err) logrus.Error(err)
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
} }
envId := params.Body.Identity
if err := createServicePolicyBind(svcName, svcId, envId, edge); err != nil { if err := createServicePolicyBind(svcName, svcId, envId, edge); err != nil {
logrus.Error(err) logrus.Error(err)
return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/go-openapi/runtime/middleware" "github.com/go-openapi/runtime/middleware"
"github.com/openziti-test-kitchen/zrok/controller/store"
"github.com/openziti-test-kitchen/zrok/rest_model_zrok" "github.com/openziti-test-kitchen/zrok/rest_model_zrok"
"github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/tunnel" "github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/tunnel"
"github.com/openziti/edge/rest_management_api_client" "github.com/openziti/edge/rest_management_api_client"
@ -11,6 +12,7 @@ import (
"github.com/openziti/edge/rest_management_api_client/service" "github.com/openziti/edge/rest_management_api_client/service"
"github.com/openziti/edge/rest_management_api_client/service_edge_router_policy" "github.com/openziti/edge/rest_management_api_client/service_edge_router_policy"
"github.com/openziti/edge/rest_management_api_client/service_policy" "github.com/openziti/edge/rest_management_api_client/service_policy"
"github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"time" "time"
) )
@ -18,12 +20,42 @@ import (
func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Principal) middleware.Responder { func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Principal) middleware.Responder {
logrus.Infof("untunneling for '%v' (%v)", principal.Username, principal.Token) logrus.Infof("untunneling for '%v' (%v)", principal.Username, principal.Token)
tx, err := str.Begin()
if err != nil {
logrus.Errorf("error starting transaction: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
defer func() { _ = tx.Rollback() }()
edge, err := edgeClient() edge, err := edgeClient()
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
} }
svcName := params.Body.Service svcName := params.Body.Service
svcId, err := findServiceId(svcName, edge)
if err != nil {
logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
var ssvc *store.Service
if svcs, err := str.FindServicesForAccount(int(principal.ID), tx); err == nil {
for _, svc := range svcs {
if svc.ZitiId == svcId {
ssvc = svc
break
}
}
if ssvc == nil {
err := errors.Errorf("service with id '%v' not found for '%v'", svcId, principal.Username)
logrus.Error(err)
return tunnel.NewUntunnelNotFound().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
} else {
logrus.Errorf("error finding services for account '%v'", principal.Username)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
if err := deleteEdgeRouterPolicy(svcName, edge); err != nil { if err := deleteEdgeRouterPolicy(svcName, edge); err != nil {
logrus.Error(err) logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
@ -40,47 +72,47 @@ func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Pr
logrus.Error(err) logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
} }
svcId, err := deleteService(svcName, edge) if err := deleteService(svcId, edge); err != nil {
if err != nil {
logrus.Error(err) logrus.Error(err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
} }
logrus.Infof("deallocated service '%v'", svcName) logrus.Infof("deallocated service '%v'", svcName)
tx, err := str.Begin() ssvc.Active = false
if err != nil { if err := str.UpdateService(ssvc, tx); err != nil {
logrus.Errorf("error starting transaction: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
defer func() { _ = tx.Rollback() }()
svcs, err := str.FindServicesForAccount(int(principal.ID), tx)
if err != nil {
logrus.Errorf("error finding services for account: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
}
changed := false
for _, svc := range svcs {
if svc.ZitiId == svcId {
svc.Active = false
if err := str.UpdateService(svc, tx); err != nil {
logrus.Errorf("error deactivating service '%v': %v", svcId, err) logrus.Errorf("error deactivating service '%v': %v", svcId, err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
} }
changed = true
logrus.Infof("deactivated service '%v'", svcId)
}
}
if changed {
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
logrus.Errorf("error committing: %v", err) logrus.Errorf("error committing: %v", err)
return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error()))
} }
}
return tunnel.NewUntunnelOK() return tunnel.NewUntunnelOK()
} }
func findServiceId(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) {
filter := fmt.Sprintf("name=\"%v\"", svcName)
limit := int64(1)
offset := int64(0)
listReq := &service.ListServicesParams{
Filter: &filter,
Limit: &limit,
Offset: &offset,
Context: context.Background(),
}
listReq.SetTimeout(30 * time.Second)
listResp, err := edge.Service.ListServices(listReq, nil)
if err != nil {
return "", err
}
if len(listResp.Payload.Data) == 1 {
return *(listResp.Payload.Data[0].ID), nil
}
return "", errors.Errorf("service '%v' not found", svcName)
}
func deleteEdgeRouterPolicy(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) error { func deleteEdgeRouterPolicy(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) error {
filter := fmt.Sprintf("name=\"%v\"", svcName) filter := fmt.Sprintf("name=\"%v\"", svcName)
limit := int64(1) limit := int64(1)
@ -187,23 +219,7 @@ func deleteServicePolicy(filter string, edge *rest_management_api_client.ZitiEdg
return nil return nil
} }
func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) { func deleteService(svcId string, edge *rest_management_api_client.ZitiEdgeManagement) error {
filter := fmt.Sprintf("name=\"%v\"", svcName)
limit := int64(1)
offset := int64(0)
listReq := &service.ListServicesParams{
Filter: &filter,
Limit: &limit,
Offset: &offset,
Context: context.Background(),
}
listReq.SetTimeout(30 * time.Second)
listResp, err := edge.Service.ListServices(listReq, nil)
if err != nil {
return "", err
}
if len(listResp.Payload.Data) == 1 {
svcId := *(listResp.Payload.Data[0].ID)
req := &service.DeleteServiceParams{ req := &service.DeleteServiceParams{
ID: svcId, ID: svcId,
Context: context.Background(), Context: context.Background(),
@ -211,12 +227,8 @@ func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeMana
req.SetTimeout(30 * time.Second) req.SetTimeout(30 * time.Second)
_, err := edge.Service.DeleteService(req, nil) _, err := edge.Service.DeleteService(req, nil)
if err != nil { if err != nil {
return "", err return err
} }
logrus.Infof("deleted service '%v'", svcId) logrus.Infof("deleted service '%v'", svcId)
return svcId, nil return nil
} else {
logrus.Infof("did not find a service")
}
return "", nil
} }

View File

@ -29,6 +29,12 @@ func (o *TunnelReader) ReadResponse(response runtime.ClientResponse, consumer ru
return nil, err return nil, err
} }
return result, nil return result, nil
case 401:
result := NewTunnelUnauthorized()
if err := result.readResponse(response, consumer, o.formats); err != nil {
return nil, err
}
return nil, result
case 500: case 500:
result := NewTunnelInternalServerError() result := NewTunnelInternalServerError()
if err := result.readResponse(response, consumer, o.formats); err != nil { if err := result.readResponse(response, consumer, o.formats); err != nil {
@ -72,6 +78,36 @@ func (o *TunnelCreated) readResponse(response runtime.ClientResponse, consumer r
return nil return nil
} }
// NewTunnelUnauthorized creates a TunnelUnauthorized with default headers values
func NewTunnelUnauthorized() *TunnelUnauthorized {
return &TunnelUnauthorized{}
}
/* TunnelUnauthorized describes a response with status code 401, with default header values.
invalid environment identity
*/
type TunnelUnauthorized struct {
Payload rest_model_zrok.ErrorMessage
}
func (o *TunnelUnauthorized) Error() string {
return fmt.Sprintf("[POST /tunnel][%d] tunnelUnauthorized %+v", 401, o.Payload)
}
func (o *TunnelUnauthorized) GetPayload() rest_model_zrok.ErrorMessage {
return o.Payload
}
func (o *TunnelUnauthorized) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
// response payload
if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF {
return err
}
return nil
}
// NewTunnelInternalServerError creates a TunnelInternalServerError with default headers values // NewTunnelInternalServerError creates a TunnelInternalServerError with default headers values
func NewTunnelInternalServerError() *TunnelInternalServerError { func NewTunnelInternalServerError() *TunnelInternalServerError {
return &TunnelInternalServerError{} return &TunnelInternalServerError{}

View File

@ -29,6 +29,12 @@ func (o *UntunnelReader) ReadResponse(response runtime.ClientResponse, consumer
return nil, err return nil, err
} }
return result, nil return result, nil
case 404:
result := NewUntunnelNotFound()
if err := result.readResponse(response, consumer, o.formats); err != nil {
return nil, err
}
return nil, result
case 500: case 500:
result := NewUntunnelInternalServerError() result := NewUntunnelInternalServerError()
if err := result.readResponse(response, consumer, o.formats); err != nil { if err := result.readResponse(response, consumer, o.formats); err != nil {
@ -61,6 +67,36 @@ func (o *UntunnelOK) readResponse(response runtime.ClientResponse, consumer runt
return nil return nil
} }
// NewUntunnelNotFound creates a UntunnelNotFound with default headers values
func NewUntunnelNotFound() *UntunnelNotFound {
return &UntunnelNotFound{}
}
/* UntunnelNotFound describes a response with status code 404, with default header values.
not found
*/
type UntunnelNotFound struct {
Payload rest_model_zrok.ErrorMessage
}
func (o *UntunnelNotFound) Error() string {
return fmt.Sprintf("[DELETE /untunnel][%d] untunnelNotFound %+v", 404, o.Payload)
}
func (o *UntunnelNotFound) GetPayload() rest_model_zrok.ErrorMessage {
return o.Payload
}
func (o *UntunnelNotFound) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error {
// response payload
if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF {
return err
}
return nil
}
// NewUntunnelInternalServerError creates a UntunnelInternalServerError with default headers values // NewUntunnelInternalServerError creates a UntunnelInternalServerError with default headers values
func NewUntunnelInternalServerError() *UntunnelInternalServerError { func NewUntunnelInternalServerError() *UntunnelInternalServerError {
return &UntunnelInternalServerError{} return &UntunnelInternalServerError{}

View File

@ -131,6 +131,12 @@ func init() {
"$ref": "#/definitions/tunnelResponse" "$ref": "#/definitions/tunnelResponse"
} }
}, },
"401": {
"description": "invalid environment identity",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": { "500": {
"description": "internal server error", "description": "internal server error",
"schema": { "schema": {
@ -164,6 +170,12 @@ func init() {
"200": { "200": {
"description": "tunnel removed" "description": "tunnel removed"
}, },
"404": {
"description": "not found",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": { "500": {
"description": "internal server error", "description": "internal server error",
"schema": { "schema": {
@ -391,6 +403,12 @@ func init() {
"$ref": "#/definitions/tunnelResponse" "$ref": "#/definitions/tunnelResponse"
} }
}, },
"401": {
"description": "invalid environment identity",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": { "500": {
"description": "internal server error", "description": "internal server error",
"schema": { "schema": {
@ -424,6 +442,12 @@ func init() {
"200": { "200": {
"description": "tunnel removed" "description": "tunnel removed"
}, },
"404": {
"description": "not found",
"schema": {
"$ref": "#/definitions/errorMessage"
}
},
"500": { "500": {
"description": "internal server error", "description": "internal server error",
"schema": { "schema": {

View File

@ -57,6 +57,48 @@ func (o *TunnelCreated) WriteResponse(rw http.ResponseWriter, producer runtime.P
} }
} }
// TunnelUnauthorizedCode is the HTTP code returned for type TunnelUnauthorized
const TunnelUnauthorizedCode int = 401
/*TunnelUnauthorized invalid environment identity
swagger:response tunnelUnauthorized
*/
type TunnelUnauthorized struct {
/*
In: Body
*/
Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"`
}
// NewTunnelUnauthorized creates TunnelUnauthorized with default headers values
func NewTunnelUnauthorized() *TunnelUnauthorized {
return &TunnelUnauthorized{}
}
// WithPayload adds the payload to the tunnel unauthorized response
func (o *TunnelUnauthorized) WithPayload(payload rest_model_zrok.ErrorMessage) *TunnelUnauthorized {
o.Payload = payload
return o
}
// SetPayload sets the payload to the tunnel unauthorized response
func (o *TunnelUnauthorized) SetPayload(payload rest_model_zrok.ErrorMessage) {
o.Payload = payload
}
// WriteResponse to the client
func (o *TunnelUnauthorized) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
rw.WriteHeader(401)
payload := o.Payload
if err := producer.Produce(rw, payload); err != nil {
panic(err) // let the recovery middleware deal with this
}
}
// TunnelInternalServerErrorCode is the HTTP code returned for type TunnelInternalServerError // TunnelInternalServerErrorCode is the HTTP code returned for type TunnelInternalServerError
const TunnelInternalServerErrorCode int = 500 const TunnelInternalServerErrorCode int = 500

View File

@ -37,6 +37,48 @@ func (o *UntunnelOK) WriteResponse(rw http.ResponseWriter, producer runtime.Prod
rw.WriteHeader(200) rw.WriteHeader(200)
} }
// UntunnelNotFoundCode is the HTTP code returned for type UntunnelNotFound
const UntunnelNotFoundCode int = 404
/*UntunnelNotFound not found
swagger:response untunnelNotFound
*/
type UntunnelNotFound struct {
/*
In: Body
*/
Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"`
}
// NewUntunnelNotFound creates UntunnelNotFound with default headers values
func NewUntunnelNotFound() *UntunnelNotFound {
return &UntunnelNotFound{}
}
// WithPayload adds the payload to the untunnel not found response
func (o *UntunnelNotFound) WithPayload(payload rest_model_zrok.ErrorMessage) *UntunnelNotFound {
o.Payload = payload
return o
}
// SetPayload sets the payload to the untunnel not found response
func (o *UntunnelNotFound) SetPayload(payload rest_model_zrok.ErrorMessage) {
o.Payload = payload
}
// WriteResponse to the client
func (o *UntunnelNotFound) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) {
rw.WriteHeader(404)
payload := o.Payload
if err := producer.Produce(rw, payload); err != nil {
panic(err) // let the recovery middleware deal with this
}
}
// UntunnelInternalServerErrorCode is the HTTP code returned for type UntunnelInternalServerError // UntunnelInternalServerErrorCode is the HTTP code returned for type UntunnelInternalServerError
const UntunnelInternalServerErrorCode int = 500 const UntunnelInternalServerErrorCode int = 500

View File

@ -70,6 +70,10 @@ paths:
description: tunnel created description: tunnel created
schema: schema:
$ref: "#/definitions/tunnelResponse" $ref: "#/definitions/tunnelResponse"
401:
description: invalid environment identity
schema:
$ref: "#/definitions/errorMessage"
500: 500:
description: internal server error description: internal server error
schema: schema:
@ -89,6 +93,10 @@ paths:
responses: responses:
200: 200:
description: tunnel removed description: tunnel removed
404:
description: not found
schema:
$ref: "#/definitions/errorMessage"
500: 500:
description: internal server error description: internal server error
schema: schema: