diff --git a/controller/enable.go b/controller/enable.go index 5d9d688f..59ceaf47 100644 --- a/controller/enable.go +++ b/controller/enable.go @@ -19,6 +19,13 @@ import ( ) func enableHandler(_ identity.EnableParams, principal *rest_model_zrok.Principal) middleware.Responder { + // start transaction early; if it fails, don't bother creating ziti resources + tx, err := str.Begin() + if err != nil { + logrus.Errorf("error starting transaction: %v", err) + return identity.NewCreateAccountInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + client, err := edgeClient() if err != nil { logrus.Errorf("error getting edge client: %v", err) @@ -35,11 +42,6 @@ func enableHandler(_ identity.EnableParams, principal *rest_model_zrok.Principal return identity.NewEnableInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } - tx, err := str.Begin() - if err != nil { - logrus.Errorf("error starting transaction: %v", err) - return identity.NewCreateAccountInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) - } iid, err := str.CreateIdentity(int(principal.ID), &store.Identity{ZitiId: ident.Payload.Data.ID}, tx) if err != nil { logrus.Errorf("error storing created identity: %v", err) diff --git a/controller/tunnel.go b/controller/tunnel.go index 960af0d1..93362f65 100644 --- a/controller/tunnel.go +++ b/controller/tunnel.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/go-openapi/runtime/middleware" + "github.com/openziti-test-kitchen/zrok/controller/store" "github.com/openziti-test-kitchen/zrok/rest_model_zrok" "github.com/openziti-test-kitchen/zrok/rest_server_zrok/operations/tunnel" "github.com/openziti/edge/rest_management_api_client" @@ -18,45 +19,63 @@ import ( func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Principal) middleware.Responder { logrus.Infof("tunneling for '%v' (%v)", principal.Username, principal.Token) + + tx, err := str.Begin() + if err != nil { + logrus.Errorf("error starting transaction: %v", err) + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + edge, err := edgeClient() if err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } svcName, err := randomId() if err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } svcId, err := createService(svcName, edge) if err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } envId := params.Body.Identity if err := createServicePolicyBind(svcName, svcId, envId, edge); err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } if err := createServicePolicyDial(svcName, svcId, edge); err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } if err := createServiceEdgeRouterPolicy(svcName, svcId, edge); err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } if err := createEdgeRouterPolicy(svcName, envId, edge); err != nil { logrus.Error(err) - return tunnel.NewTunnelInternalServerError() + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } logrus.Infof("allocated service '%v'", svcName) - resp := tunnel.NewTunnelCreated().WithPayload(&rest_model_zrok.TunnelResponse{ + sid, err := str.CreateService(int(principal.ID), &store.Service{ZitiId: svcId, Endpoint: params.Body.Endpoint}, tx) + if err != nil { + logrus.Errorf("error creating service record: %v", err) + _ = tx.Rollback() + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + if err := tx.Commit(); err != nil { + logrus.Errorf("error committing service record: %v", err) + return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + logrus.Infof("recorded service '%v' with id '%v' for '%v'", svcId, sid, principal.Username) + + return tunnel.NewTunnelCreated().WithPayload(&rest_model_zrok.TunnelResponse{ Service: svcName, }) - return resp } func createService(name string, edge *rest_management_api_client.ZitiEdgeManagement) (serviceId string, err error) {