diff --git a/controller/store/identity.go b/controller/store/identity.go index 868b2e4e..91e6d5d0 100644 --- a/controller/store/identity.go +++ b/controller/store/identity.go @@ -9,10 +9,11 @@ type Identity struct { Model AccountId int ZitiId string + Active bool } func (self *Store) CreateIdentity(accountId int, i *Identity, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into identities (account_id, ziti_id) values (?, ?)") + stmt, err := tx.Prepare("insert into identities (account_id, ziti_id, active) values (?, ?, true)") if err != nil { return 0, errors.Wrap(err, "error preparing identities insert statement") } diff --git a/controller/store/service.go b/controller/store/service.go index 0c584b7d..df736799 100644 --- a/controller/store/service.go +++ b/controller/store/service.go @@ -10,10 +10,11 @@ type Service struct { AccountId int ZitiId string Endpoint string + Active bool } func (self *Store) CreateService(accountId int, svc *Service, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into services (account_id, ziti_id, endpoint) values (?, ?, ?)") + stmt, err := tx.Prepare("insert into services (account_id, ziti_id, endpoint, active) values (?, ?, ?, true)") if err != nil { return 0, errors.Wrap(err, "error preparing services insert statement") } @@ -52,6 +53,18 @@ func (self *Store) FindServicesForAccount(accountId int, tx *sqlx.Tx) ([]*Servic return svcs, nil } +func (self *Store) DeactivateService(id int, tx *sqlx.Tx) error { + stmt, err := tx.Prepare("update services set active=false where id = ?") + if err != nil { + return errors.Wrap(err, "error preparing services deactivate statement") + } + _, err = stmt.Exec(id) + if err != nil { + return errors.Wrap(err, "error executing services deactivate statement") + } + return nil +} + func (self *Store) DeleteService(id int, tx *sqlx.Tx) error { stmt, err := tx.Prepare("delete from services where id = ?") if err != nil { diff --git a/controller/store/sql/000_base.sql b/controller/store/sql/000_base.sql index 9b733942..af4e92e8 100644 --- a/controller/store/sql/000_base.sql +++ b/controller/store/sql/000_base.sql @@ -23,8 +23,9 @@ create table identities ( id integer primary key, account_id integer constraint fk_accounts_identities references accounts on delete cascade, ziti_id string not null unique, - created_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')), - updated_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')), + active boolean not null, + created_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')), + updated_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')), constraint chk_ziti_id check (ziti_id <> '') ); @@ -37,6 +38,7 @@ create table services ( account_id integer constraint fk_accounts_services references accounts on delete cascade, ziti_id string not null unique, endpoint string, + active boolean not null, created_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')), updated_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')), diff --git a/controller/tunnel.go b/controller/tunnel.go index 93362f65..8279568f 100644 --- a/controller/tunnel.go +++ b/controller/tunnel.go @@ -25,6 +25,7 @@ func tunnelHandler(params tunnel.TunnelParams, principal *rest_model_zrok.Princi logrus.Errorf("error starting transaction: %v", err) return tunnel.NewTunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } + defer func() { _ = tx.Rollback() }() edge, err := edgeClient() if err != nil { diff --git a/controller/untunnel.go b/controller/untunnel.go index 3c5a8d36..eef438b8 100644 --- a/controller/untunnel.go +++ b/controller/untunnel.go @@ -17,35 +17,62 @@ import ( func untunnelHandler(params tunnel.UntunnelParams, principal *rest_model_zrok.Principal) middleware.Responder { logrus.Infof("untunneling for '%v' (%v)", principal.Username, principal.Token) + edge, err := edgeClient() if err != nil { logrus.Error(err) - return tunnel.NewUntunnelInternalServerError() + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } svcName := params.Body.Service if err := deleteEdgeRouterPolicy(svcName, edge); err != nil { logrus.Error(err) - return tunnel.NewUntunnelInternalServerError() + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } if err := deleteServiceEdgeRouterPolicy(svcName, edge); err != nil { logrus.Error(err) - return tunnel.NewUntunnelInternalServerError() + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } if err := deleteServicePolicyDial(svcName, edge); err != nil { logrus.Error(err) - return tunnel.NewUntunnelInternalServerError() + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } if err := deleteServicePolicyBind(svcName, edge); err != nil { logrus.Error(err) - return tunnel.NewUntunnelInternalServerError() + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } - if err := deleteService(svcName, edge); err != nil { + svcId, err := deleteService(svcName, edge) + if err != nil { logrus.Error(err) return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) } logrus.Infof("deallocated service '%v'", svcName) + tx, err := str.Begin() + if err != nil { + logrus.Errorf("error starting transaction: %v", err) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + defer func() { _ = tx.Rollback() }() + svcs, err := str.FindServicesForAccount(int(principal.ID), tx) + if err != nil { + logrus.Errorf("error finding services for account: %v", err) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + for _, svc := range svcs { + if svc.ZitiId == svcId { + if err := str.DeactivateService(svc.Id, tx); err != nil { + logrus.Errorf("error deactivating service '%v': %v", svcId, err) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + if err := tx.Commit(); err != nil { + logrus.Errorf("error committing: %v", err) + return tunnel.NewUntunnelInternalServerError().WithPayload(rest_model_zrok.ErrorMessage(err.Error())) + } + logrus.Infof("deactivated service '%v'", svcId) + } + } + return tunnel.NewUntunnelOK() } @@ -155,7 +182,7 @@ func deleteServicePolicy(filter string, edge *rest_management_api_client.ZitiEdg return nil } -func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) error { +func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeManagement) (string, error) { filter := fmt.Sprintf("name=\"%v\"", svcName) limit := int64(1) offset := int64(0) @@ -168,7 +195,7 @@ func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeMana listReq.SetTimeout(30 * time.Second) listResp, err := edge.Service.ListServices(listReq, nil) if err != nil { - return err + return "", err } if len(listResp.Payload.Data) == 1 { svcId := *(listResp.Payload.Data[0].ID) @@ -179,11 +206,12 @@ func deleteService(svcName string, edge *rest_management_api_client.ZitiEdgeMana req.SetTimeout(30 * time.Second) _, err := edge.Service.DeleteService(req, nil) if err != nil { - return err + return "", err } logrus.Infof("deleted service '%v'", svcId) + return svcId, nil } else { logrus.Infof("did not find a service") } - return nil + return "", nil }