diff --git a/controller/store/account.go b/controller/store/account.go index 4bc0150c..cc7094e7 100644 --- a/controller/store/account.go +++ b/controller/store/account.go @@ -13,24 +13,20 @@ type Account struct { } func (self *Store) CreateAccount(a *Account, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into accounts (email, password, token) values (?, ?, ?)") + stmt, err := tx.Prepare("insert into accounts (email, password, token) values ($1, $2, $3) returning id") if err != nil { return 0, errors.Wrap(err, "error preparing accounts insert statement") } - res, err := stmt.Exec(a.Email, a.Password, a.Token) - if err != nil { + var id int + if err := stmt.QueryRow(a.Email, a.Password, a.Token).Scan(&id); err != nil { return 0, errors.Wrap(err, "error executing accounts insert statement") } - id, err := res.LastInsertId() - if err != nil { - return 0, errors.Wrap(err, "error retrieving last accounts insert id") - } - return int(id), nil + return id, nil } func (self *Store) GetAccount(id int, tx *sqlx.Tx) (*Account, error) { a := &Account{} - if err := tx.QueryRowx("select * from accounts where id = ?", id).StructScan(a); err != nil { + if err := tx.QueryRowx("select * from accounts where id = $1", id).StructScan(a); err != nil { return nil, errors.Wrap(err, "error selecting account by id") } return a, nil @@ -38,7 +34,7 @@ func (self *Store) GetAccount(id int, tx *sqlx.Tx) (*Account, error) { func (self *Store) FindAccountWithEmail(email string, tx *sqlx.Tx) (*Account, error) { a := &Account{} - if err := tx.QueryRowx("select * from accounts where email = ?", email).StructScan(a); err != nil { + if err := tx.QueryRowx("select * from accounts where email = $1", email).StructScan(a); err != nil { return nil, errors.Wrap(err, "error selecting account by email") } return a, nil @@ -46,7 +42,7 @@ func (self *Store) FindAccountWithEmail(email string, tx *sqlx.Tx) (*Account, er func (self *Store) FindAccountWithToken(token string, tx *sqlx.Tx) (*Account, error) { a := &Account{} - if err := tx.QueryRowx("select * from accounts where token = ?", token).StructScan(a); err != nil { + if err := tx.QueryRowx("select * from accounts where token = $1", token).StructScan(a); err != nil { return nil, errors.Wrap(err, "error selecting account by token") } return a, nil diff --git a/controller/store/account_request.go b/controller/store/account_request.go index 45e8ce9b..e5df2c32 100644 --- a/controller/store/account_request.go +++ b/controller/store/account_request.go @@ -13,24 +13,20 @@ type AccountRequest struct { } func (self *Store) CreateAccountRequest(ar *AccountRequest, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into account_requests (token, email, source_address) values (?, ?, ?)") + stmt, err := tx.Prepare("insert into account_requests (token, email, source_address) values ($1, $2, $3) returning id") if err != nil { return 0, errors.Wrap(err, "error preparing account_requests insert statement") } - res, err := stmt.Exec(ar.Token, ar.Email, ar.SourceAddress) - if err != nil { + var id int + if err := stmt.QueryRow(ar.Token, ar.Email, ar.SourceAddress).Scan(&id); err != nil { return 0, errors.Wrap(err, "error executing account_requests insert statement") } - id, err := res.LastInsertId() - if err != nil { - return 0, errors.Wrap(err, "error retrieving last account_requests insert id") - } - return int(id), nil + return id, nil } func (self *Store) GetAccountRequest(id int, tx *sqlx.Tx) (*AccountRequest, error) { ar := &AccountRequest{} - if err := tx.QueryRowx("select * from account_requests where id = ?", id).StructScan(ar); err != nil { + if err := tx.QueryRowx("select * from account_requests where id = $1", id).StructScan(ar); err != nil { return nil, errors.Wrap(err, "error selecting account_request by id") } return ar, nil @@ -38,7 +34,7 @@ func (self *Store) GetAccountRequest(id int, tx *sqlx.Tx) (*AccountRequest, erro func (self *Store) FindAccountRequestWithToken(token string, tx *sqlx.Tx) (*AccountRequest, error) { ar := &AccountRequest{} - if err := tx.QueryRowx("select * from account_requests where token = ?", token).StructScan(ar); err != nil { + if err := tx.QueryRowx("select * from account_requests where token = $1", token).StructScan(ar); err != nil { return nil, errors.Wrap(err, "error selecting account_request by token") } return ar, nil @@ -46,14 +42,14 @@ func (self *Store) FindAccountRequestWithToken(token string, tx *sqlx.Tx) (*Acco func (self *Store) FindAccountRequestWithEmail(email string, tx *sqlx.Tx) (*AccountRequest, error) { ar := &AccountRequest{} - if err := tx.QueryRowx("select * from account_requests where email = ?", email).StructScan(ar); err != nil { + if err := tx.QueryRowx("select * from account_requests where email = $1", email).StructScan(ar); err != nil { return nil, errors.Wrap(err, "error selecting account_request by email") } return ar, nil } func (self *Store) DeleteAccountRequest(id int, tx *sqlx.Tx) error { - stmt, err := tx.Prepare("delete from account_requests where id = ?") + stmt, err := tx.Prepare("delete from account_requests where id = $1") if err != nil { return errors.Wrap(err, "error preparing account_requests delete statement") } diff --git a/controller/store/environment.go b/controller/store/environment.go index 216f8dc0..7d12074b 100644 --- a/controller/store/environment.go +++ b/controller/store/environment.go @@ -15,31 +15,27 @@ type Environment struct { } func (self *Store) CreateEnvironment(accountId int, i *Environment, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into environments (account_id, description, host, address, z_id) values (?, ?, ?, ?, ?)") + stmt, err := tx.Prepare("insert into environments (account_id, description, host, address, z_id) values ($1, $2, $3, $4, $5) returning id") if err != nil { return 0, errors.Wrap(err, "error preparing environments insert statement") } - res, err := stmt.Exec(accountId, i.Description, i.Host, i.Address, i.ZId) - if err != nil { + var id int + if err := stmt.QueryRow(accountId, i.Description, i.Host, i.Address, i.ZId).Scan(&id); err != nil { return 0, errors.Wrap(err, "error executing environments insert statement") } - id, err := res.LastInsertId() - if err != nil { - return 0, errors.Wrap(err, "error retrieving last environments insert id") - } - return int(id), nil + return id, nil } func (self *Store) GetEnvironment(id int, tx *sqlx.Tx) (*Environment, error) { i := &Environment{} - if err := tx.QueryRowx("select * from environments where id = ?", id).StructScan(i); err != nil { + if err := tx.QueryRowx("select * from environments where id = $1", id).StructScan(i); err != nil { return nil, errors.Wrap(err, "error selecting environment by id") } return i, nil } func (self *Store) FindEnvironmentsForAccount(accountId int, tx *sqlx.Tx) ([]*Environment, error) { - rows, err := tx.Queryx("select environments.* from environments where account_id = ?", accountId) + rows, err := tx.Queryx("select environments.* from environments where account_id = $1", accountId) if err != nil { return nil, errors.Wrap(err, "error selecting environments by account id") } @@ -55,7 +51,7 @@ func (self *Store) FindEnvironmentsForAccount(accountId int, tx *sqlx.Tx) ([]*En } func (self *Store) DeleteEnvironment(id int, tx *sqlx.Tx) error { - stmt, err := tx.Prepare("delete from environments where id = ?") + stmt, err := tx.Prepare("delete from environments where id = $1") if err != nil { return errors.Wrap(err, "error preparing environments delete statement") } diff --git a/controller/store/service.go b/controller/store/service.go index fa3e6909..96a12ddf 100644 --- a/controller/store/service.go +++ b/controller/store/service.go @@ -15,24 +15,20 @@ type Service struct { } func (self *Store) CreateService(envId int, svc *Service, tx *sqlx.Tx) (int, error) { - stmt, err := tx.Prepare("insert into services (environment_id, z_id, name, frontend, backend) values (?, ?, ?, ?, ?)") + stmt, err := tx.Prepare("insert into services (environment_id, z_id, name, frontend, backend) values ($1, $2, $3, $4, $5) returning id") if err != nil { return 0, errors.Wrap(err, "error preparing services insert statement") } - res, err := stmt.Exec(envId, svc.ZId, svc.Name, svc.Frontend, svc.Backend) - if err != nil { + var id int + if err := stmt.QueryRow(envId, svc.ZId, svc.Name, svc.Frontend, svc.Backend).Scan(&id); err != nil { return 0, errors.Wrap(err, "error executing services insert statement") } - id, err := res.LastInsertId() - if err != nil { - return 0, errors.Wrap(err, "error retrieving last services insert id") - } - return int(id), nil + return id, nil } func (self *Store) GetService(id int, tx *sqlx.Tx) (*Service, error) { svc := &Service{} - if err := tx.QueryRowx("select * from services where id = ?", id).StructScan(svc); err != nil { + if err := tx.QueryRowx("select * from services where id = $1", id).StructScan(svc); err != nil { return nil, errors.Wrap(err, "error selecting service by id") } return svc, nil @@ -55,7 +51,7 @@ func (self *Store) GetAllServices(tx *sqlx.Tx) ([]*Service, error) { } func (self *Store) FindServicesForEnvironment(envId int, tx *sqlx.Tx) ([]*Service, error) { - rows, err := tx.Queryx("select services.* from services where environment_id = ?", envId) + rows, err := tx.Queryx("select services.* from services where environment_id = $1", envId) if err != nil { return nil, errors.Wrap(err, "error selecting services by environment id") } @@ -71,7 +67,7 @@ func (self *Store) FindServicesForEnvironment(envId int, tx *sqlx.Tx) ([]*Servic } func (self *Store) UpdateService(svc *Service, tx *sqlx.Tx) error { - sql := "update services set z_id = ?, name = ?, frontend = ?, backend = ?, updated_at = strftime('%Y-%m-%d %H:%M:%f', 'now') where id = ?" + sql := "update services set z_id = $1, name = $2, frontend = $3, backend = $4, updated_at = strftime('%Y-%m-%d %H:%M:%f', 'now') where id = $5" stmt, err := tx.Prepare(sql) if err != nil { return errors.Wrap(err, "error preparing services update statement") @@ -84,7 +80,7 @@ func (self *Store) UpdateService(svc *Service, tx *sqlx.Tx) error { } func (self *Store) DeleteService(id int, tx *sqlx.Tx) error { - stmt, err := tx.Prepare("delete from services where id = ?") + stmt, err := tx.Prepare("delete from services where id = $1") if err != nil { return errors.Wrap(err, "error preparing services delete statement") } diff --git a/controller/store/sql/postgresql/000_base.sql b/controller/store/sql/postgresql/000_base.sql index 527401bc..39b3cfcd 100644 --- a/controller/store/sql/postgresql/000_base.sql +++ b/controller/store/sql/postgresql/000_base.sql @@ -4,7 +4,7 @@ -- accounts -- create table accounts ( - id integer primary key, + id serial primary key, email varchar(1024) not null unique, password char(128) not null, token varchar(32) not null unique, @@ -20,7 +20,7 @@ create table accounts ( -- account_requests -- create table account_requests ( - id integer primary key, + id serial primary key, token varchar(32) not null unique, email varchar(1024) not null unique, source_address varchar(64) not null, @@ -32,7 +32,7 @@ create table account_requests ( -- environments -- create table environments ( - id integer primary key, + id serial primary key, account_id integer constraint fk_accounts_identities references accounts on delete cascade, description text, host varchar(256), @@ -48,7 +48,7 @@ create table environments ( -- services -- create table services ( - id integer primary key, + id serial primary key, environment_id integer constraint fk_environments_services references environments on delete cascade, z_id varchar(32) not null unique, name varchar(32) not null unique, diff --git a/controller/store/store.go b/controller/store/store.go index fbe2439c..6043dda5 100644 --- a/controller/store/store.go +++ b/controller/store/store.go @@ -4,7 +4,9 @@ import ( "fmt" "github.com/iancoleman/strcase" "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" + postgresql_schema "github.com/openziti-test-kitchen/zrok/controller/store/sql/postgresql" sqlite3_schema "github.com/openziti-test-kitchen/zrok/controller/store/sql/sqlite3" "github.com/pkg/errors" migrate "github.com/rubenv/sql-migrate" @@ -20,6 +22,7 @@ type Model struct { type Config struct { Path string + Type string } type Store struct { @@ -28,15 +31,27 @@ type Store struct { } func Open(cfg *Config) (*Store, error) { - dbx, err := sqlx.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", cfg.Path)) - if err != nil { - return nil, errors.Wrapf(err, "error opening database '%v'", cfg.Path) + var dbx *sqlx.DB + var err error + + switch cfg.Type { + case "sqlite3": + dbx, err = sqlx.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", cfg.Path)) + if err != nil { + return nil, errors.Wrapf(err, "error opening database '%v'", cfg.Path) + } + dbx.DB.SetMaxOpenConns(1) + + case "postgres": + dbx, err = sqlx.Connect("postgres", cfg.Path) + if err != nil { + return nil, errors.Wrapf(err, "error opening database '%v'", cfg.Path) + } } - dbx.DB.SetMaxOpenConns(1) logrus.Infof("opened database '%v'", cfg.Path) dbx.MapperFunc(strcase.ToSnake) store := &Store{cfg: cfg, db: dbx} - if err := store.migrate(); err != nil { + if err := store.migrate(cfg); err != nil { return nil, errors.Wrapf(err, "error migrating database '%v'", cfg.Path) } return store, nil @@ -50,16 +65,31 @@ func (self *Store) Close() error { return self.db.Close() } -func (self *Store) migrate() error { - migrations := &migrate.EmbedFileSystemMigrationSource{ - FileSystem: sqlite3_schema.FS, - Root: "/", +func (self *Store) migrate(cfg *Config) error { + switch cfg.Type { + case "sqlite3": + migrations := &migrate.EmbedFileSystemMigrationSource{ + FileSystem: sqlite3_schema.FS, + Root: "/", + } + migrate.SetTable("migrations") + n, err := migrate.Exec(self.db.DB, "sqlite3", migrations, migrate.Up) + if err != nil { + return errors.Wrap(err, "error running migrations") + } + logrus.Infof("applied %d migrations", n) + + case "postgres": + migrations := &migrate.EmbedFileSystemMigrationSource{ + FileSystem: postgresql_schema.FS, + Root: "/", + } + migrate.SetTable("migrations") + n, err := migrate.Exec(self.db.DB, "postgres", migrations, migrate.Up) + if err != nil { + return errors.Wrap(err, "error running migrations") + } + logrus.Infof("applied %d migrations", n) } - migrate.SetTable("migrations") - n, err := migrate.Exec(self.db.DB, "sqlite3", migrations, migrate.Up) - if err != nil { - return errors.Wrap(err, "error running migrations") - } - logrus.Infof("applied %d migrations", n) return nil } diff --git a/go.mod b/go.mod index e98c0aba..83c1762b 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/jaevor/go-nanoid v1.3.0 github.com/jessevdk/go-flags v1.5.0 github.com/jmoiron/sqlx v1.3.5 + github.com/lib/pq v1.10.0 github.com/mattn/go-sqlite3 v1.14.14 github.com/michaelquigley/cf v0.0.12 github.com/michaelquigley/pfxlog v0.6.9