diff --git a/controller/tunnel.go b/controller/tunnel.go index 955b4a73..179d2794 100644 --- a/controller/tunnel.go +++ b/controller/tunnel.go @@ -27,6 +27,25 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi } defer func() { _ = tx.Rollback() }() + envId := params.Body.Identity + if is, err := str.FindIdentitiesForAccount(int(principal.ID), tx); err == nil { + found := false + for _, i := range is { + if i.ZitiId == envId { + logrus.Infof("found identity '%v' for user '%v'", envId, principal.Username) + found = true + break + } + } + if !found { + logrus.Errorf("identity '%v' not found for user '%v'", envId, principal.Username) + return tunnel.NewTunnelUnauthorized().WithPayload("bad environment identity") + } + } else { + logrus.Errorf("error finding identities for account '%v'", principal.Username) + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + edge, err := edgeClient() if err != nil { logrus.Error(err) @@ -42,7 +61,6 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi logrus.Error(err) return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } - envId := params.Body.Identity if err := createServicePolicyBind(svcName, svcId, envId, edge); err != nil { logrus.Error(err) return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) diff --git a/controller/untunnel.go b/controller/untunnel.go index d7893552..8c54ba48 100644 --- a/controller/untunnel.go +++ b/controller/untunnel.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/go-openapi/runtime/middleware" + "github.com/openziti-test-kitchen/zrok/controller/store" "github.com/openziti-test-kitchen/zrok/rest_model_zrok" "github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/tunnel" "github.com/openziti/edge/rest_management_api_client" @@ -11,6 +12,7 @@ import ( "github.com/openziti/edge/rest_management_api_client/service" "github.com/openziti/edge/rest_management_api_client/service_edge_router_policy" "github.com/openziti/edge/rest_management_api_client/service_policy" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "time" ) @@ -18,12 +20,42 @@ import ( func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Principal) middleware.Responder { logrus.Infof("untunneling for '%v' (%v)", principal.Username, principal.Token) + tx, err := str.Begin() + if err != nil { + logrus.Errorf("error starting transaction: %v", err) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + defer func() { _ = tx.Rollback() }() + edge, err := edgeClient() if err != nil { logrus.Error(err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } svcName := params.Body.Service + svcId, err := findServiceId(svcName, edge) + if err != nil { + logrus.Error(err) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + var ssvc *store.Service + if svcs, err := str.FindServicesForAccount(int(principal.ID), tx); err == nil { + for _, svc := range svcs { + if svc.ZitiId == svcId { + ssvc = svc + break + } + } + if ssvc == nil { + err := errors.Errorf("service with id '%v' not found for '%v'", svcId, principal.Username) + logrus.Error(err) + return tunnel.NewUntunnelNotFound().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + } else { + logrus.Errorf("error finding services for account '%v'", principal.Username) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + if err := deleteEdgeRouterPolicy(svcName, edge); err != nil { logrus.Error(err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) @@ -40,47 +72,47 @@ func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Pr logrus.Error(err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } - svcId, err := deleteService(svcName, edge) - if err != nil { + if err := deleteService(svcId, edge); err != nil { logrus.Error(err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } logrus.Infof("deallocated service '%v'", svcName) - tx, err := str.Begin() - if err != nil { - logrus.Errorf("error starting transaction: %v", err) + ssvc.Active = false + if err := str.UpdateService(ssvc, tx); err != nil { + logrus.Errorf("error deactivating service '%v': %v", svcId, err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } - defer func() { _ = tx.Rollback() }() - svcs, err := str.FindServicesForAccount(int(principal.ID), tx) - if err != nil { - logrus.Errorf("error finding services for account: %v", err) + if err := tx.Commit(); err != nil { + logrus.Errorf("error committing: %v", err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } - changed := false - for _, svc := range svcs { - if svc.ZitiId == svcId { - svc.Active = false - if err := str.UpdateService(svc, tx); err != nil { - logrus.Errorf("error deactivating service '%v': %v", svcId, err) - return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) - } - changed = true - logrus.Infof("deactivated service '%v'", svcId) - } - } - if changed { - if err := tx.Commit(); err != nil { - logrus.Errorf("error committing: %v", err) - return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) - } - } return tunnel.NewUntunnelOK() } +func findServiceId(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) { + filter := fmt.Sprintf("name=\"%v\"", svcName) + limit := int64(1) + offset := int64(0) + listReq := &service.ListServicesParams{ + Filter: &filter, + Limit: &limit, + Offset: &offset, + Context: context.Background(), + } + listReq.SetTimeout(30 * time.Second) + listResp, err := edge.Service.ListServices(listReq, nil) + if err != nil { + return "", err + } + if len(listResp.Payload.Data) == 1 { + return *(listResp.Payload.Data[0].ID), nil + } + return "", errors.Errorf("service '%v' not found", svcName) +} + func deleteEdgeRouterPolicy(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) error { filter := fmt.Sprintf("name=\"%v\"", svcName) limit := int64(1) @@ -187,36 +219,16 @@ func deleteServicePolicy(filter string, edge *rest_management_api_client.ZitiEdg return nil } -func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) { - filter := fmt.Sprintf("name=\"%v\"", svcName) - limit := int64(1) - offset := int64(0) - listReq := &service.ListServicesParams{ - Filter: &filter, - Limit: &limit, - Offset: &offset, +func deleteService(svcId string, edge *rest_management_api_client.ZitiEdgeManagement) error { + req := &service.DeleteServiceParams{ + ID: svcId, Context: context.Background(), } - listReq.SetTimeout(30 * time.Second) - listResp, err := edge.Service.ListServices(listReq, nil) + req.SetTimeout(30 * time.Second) + _, err := edge.Service.DeleteService(req, nil) if err != nil { - return "", err + return err } - if len(listResp.Payload.Data) == 1 { - svcId := *(listResp.Payload.Data[0].ID) - req := &service.DeleteServiceParams{ - ID: svcId, - Context: context.Background(), - } - req.SetTimeout(30 * time.Second) - _, err := edge.Service.DeleteService(req, nil) - if err != nil { - return "", err - } - logrus.Infof("deleted service '%v'", svcId) - return svcId, nil - } else { - logrus.Infof("did not find a service") - } - return "", nil + logrus.Infof("deleted service '%v'", svcId) + return nil } diff --git a/rest_client_zrok/tunnel/tunnel_responses.go b/rest_client_zrok/tunnel/tunnel_responses.go index 2f9d0f35..c73137db 100644 --- a/rest_client_zrok/tunnel/tunnel_responses.go +++ b/rest_client_zrok/tunnel/tunnel_responses.go @@ -29,6 +29,12 @@ func (o *TunnelReader) ReadResponse(response runtime.ClientResponse, consumer ru return nil, err } return result, nil + case 401: + result := NewTunnelUnauthorized() + if err := result.readResponse(response, consumer, o.formats); err != nil { + return nil, err + } + return nil, result case 500: result := NewTunnelInternalServerError() if err := result.readResponse(response, consumer, o.formats); err != nil { @@ -72,6 +78,36 @@ func (o *TunnelCreated) readResponse(response runtime.ClientResponse, consumer r return nil } +// NewTunnelUnauthorized creates a TunnelUnauthorized with default headers values +func NewTunnelUnauthorized() *TunnelUnauthorized { + return &TunnelUnauthorized{} +} + +/* TunnelUnauthorized describes a response with status code 401, with default header values. + +invalid environment identity +*/ +type TunnelUnauthorized struct { + Payload rest_model_zrok.ErrorMessage +} + +func (o *TunnelUnauthorized) Error() string { + return fmt.Sprintf("[POST /tunnel][%d] tunnelUnauthorized %+v", 401, o.Payload) +} +func (o *TunnelUnauthorized) GetPayload() rest_model_zrok.ErrorMessage { + return o.Payload +} + +func (o *TunnelUnauthorized) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error { + + // response payload + if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF { + return err + } + + return nil +} + // NewTunnelInternalServerError creates a TunnelInternalServerError with default headers values func NewTunnelInternalServerError() *TunnelInternalServerError { return &TunnelInternalServerError{} diff --git a/rest_client_zrok/tunnel/untunnel_responses.go b/rest_client_zrok/tunnel/untunnel_responses.go index 388758f7..e4a2eb34 100644 --- a/rest_client_zrok/tunnel/untunnel_responses.go +++ b/rest_client_zrok/tunnel/untunnel_responses.go @@ -29,6 +29,12 @@ func (o *UntunnelReader) ReadResponse(response runtime.ClientResponse, consumer return nil, err } return result, nil + case 404: + result := NewUntunnelNotFound() + if err := result.readResponse(response, consumer, o.formats); err != nil { + return nil, err + } + return nil, result case 500: result := NewUntunnelInternalServerError() if err := result.readResponse(response, consumer, o.formats); err != nil { @@ -61,6 +67,36 @@ func (o *UntunnelOK) readResponse(response runtime.ClientResponse, consumer runt return nil } +// NewUntunnelNotFound creates a UntunnelNotFound with default headers values +func NewUntunnelNotFound() *UntunnelNotFound { + return &UntunnelNotFound{} +} + +/* UntunnelNotFound describes a response with status code 404, with default header values. + +not found +*/ +type UntunnelNotFound struct { + Payload rest_model_zrok.ErrorMessage +} + +func (o *UntunnelNotFound) Error() string { + return fmt.Sprintf("[DELETE /untunnel][%d] untunnelNotFound %+v", 404, o.Payload) +} +func (o *UntunnelNotFound) GetPayload() rest_model_zrok.ErrorMessage { + return o.Payload +} + +func (o *UntunnelNotFound) readResponse(response runtime.ClientResponse, consumer runtime.Consumer, formats strfmt.Registry) error { + + // response payload + if err := consumer.Consume(response.Body(), &o.Payload); err != nil && err != io.EOF { + return err + } + + return nil +} + // NewUntunnelInternalServerError creates a UntunnelInternalServerError with default headers values func NewUntunnelInternalServerError() *UntunnelInternalServerError { return &UntunnelInternalServerError{} diff --git a/rest_server_zrok/embedded_spec.go b/rest_server_zrok/embedded_spec.go index 2f2e896f..cbb0b802 100644 --- a/rest_server_zrok/embedded_spec.go +++ b/rest_server_zrok/embedded_spec.go @@ -131,6 +131,12 @@ func init() { "$ref": "#/definitions/tunnelResponse" } }, + "401": { + "description": "invalid environment identity", + "schema": { + "$ref": "#/definitions/errorMessage" + } + }, "500": { "description": "internal server error", "schema": { @@ -164,6 +170,12 @@ func init() { "200": { "description": "tunnel removed" }, + "404": { + "description": "not found", + "schema": { + "$ref": "#/definitions/errorMessage" + } + }, "500": { "description": "internal server error", "schema": { @@ -391,6 +403,12 @@ func init() { "$ref": "#/definitions/tunnelResponse" } }, + "401": { + "description": "invalid environment identity", + "schema": { + "$ref": "#/definitions/errorMessage" + } + }, "500": { "description": "internal server error", "schema": { @@ -424,6 +442,12 @@ func init() { "200": { "description": "tunnel removed" }, + "404": { + "description": "not found", + "schema": { + "$ref": "#/definitions/errorMessage" + } + }, "500": { "description": "internal server error", "schema": { diff --git a/rest_server_zrok/operations/tunnel/tunnel_responses.go b/rest_server_zrok/operations/tunnel/tunnel_responses.go index a30bb4ea..f3dc3b5e 100644 --- a/rest_server_zrok/operations/tunnel/tunnel_responses.go +++ b/rest_server_zrok/operations/tunnel/tunnel_responses.go @@ -57,6 +57,48 @@ func (o *TunnelCreated) WriteResponse(rw http.ResponseWriter, producer runtime.P } } +// TunnelUnauthorizedCode is the HTTP code returned for type TunnelUnauthorized +const TunnelUnauthorizedCode int = 401 + +/*TunnelUnauthorized invalid environment identity + +swagger:response tunnelUnauthorized +*/ +type TunnelUnauthorized struct { + + /* + In: Body + */ + Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"` +} + +// NewTunnelUnauthorized creates TunnelUnauthorized with default headers values +func NewTunnelUnauthorized() *TunnelUnauthorized { + + return &TunnelUnauthorized{} +} + +// WithPayload adds the payload to the tunnel unauthorized response +func (o *TunnelUnauthorized) WithPayload(payload rest_model_zrok.ErrorMessage) *TunnelUnauthorized { + o.Payload = payload + return o +} + +// SetPayload sets the payload to the tunnel unauthorized response +func (o *TunnelUnauthorized) SetPayload(payload rest_model_zrok.ErrorMessage) { + o.Payload = payload +} + +// WriteResponse to the client +func (o *TunnelUnauthorized) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.WriteHeader(401) + payload := o.Payload + if err := producer.Produce(rw, payload); err != nil { + panic(err) // let the recovery middleware deal with this + } +} + // TunnelInternalServerErrorCode is the HTTP code returned for type TunnelInternalServerError const TunnelInternalServerErrorCode int = 500 diff --git a/rest_server_zrok/operations/tunnel/untunnel_responses.go b/rest_server_zrok/operations/tunnel/untunnel_responses.go index 4aafc655..796c9178 100644 --- a/rest_server_zrok/operations/tunnel/untunnel_responses.go +++ b/rest_server_zrok/operations/tunnel/untunnel_responses.go @@ -37,6 +37,48 @@ func (o *UntunnelOK) WriteResponse(rw http.ResponseWriter, producer runtime.Prod rw.WriteHeader(200) } +// UntunnelNotFoundCode is the HTTP code returned for type UntunnelNotFound +const UntunnelNotFoundCode int = 404 + +/*UntunnelNotFound not found + +swagger:response untunnelNotFound +*/ +type UntunnelNotFound struct { + + /* + In: Body + */ + Payload rest_model_zrok.ErrorMessage `json:"body,omitempty"` +} + +// NewUntunnelNotFound creates UntunnelNotFound with default headers values +func NewUntunnelNotFound() *UntunnelNotFound { + + return &UntunnelNotFound{} +} + +// WithPayload adds the payload to the untunnel not found response +func (o *UntunnelNotFound) WithPayload(payload rest_model_zrok.ErrorMessage) *UntunnelNotFound { + o.Payload = payload + return o +} + +// SetPayload sets the payload to the untunnel not found response +func (o *UntunnelNotFound) SetPayload(payload rest_model_zrok.ErrorMessage) { + o.Payload = payload +} + +// WriteResponse to the client +func (o *UntunnelNotFound) WriteResponse(rw http.ResponseWriter, producer runtime.Producer) { + + rw.WriteHeader(404) + payload := o.Payload + if err := producer.Produce(rw, payload); err != nil { + panic(err) // let the recovery middleware deal with this + } +} + // UntunnelInternalServerErrorCode is the HTTP code returned for type UntunnelInternalServerError const UntunnelInternalServerErrorCode int = 500 diff --git a/specs/zrok.yml b/specs/zrok.yml index 587781ae..cb7eb2ec 100644 --- a/specs/zrok.yml +++ b/specs/zrok.yml @@ -70,6 +70,10 @@ paths: description: tunnel created schema: $ref: "#/definitions/tunnelResponse" + 401: + description: invalid environment identity + schema: + $ref: "#/definitions/errorMessage" 500: description: internal server error schema: @@ -89,6 +93,10 @@ paths: responses: 200: description: tunnel removed + 404: + description: not found + schema: + $ref: "#/definitions/errorMessage" 500: description: internal server error schema: