changes to support the agreement between sqlite and postgres (#46)

This commit is contained in:
Michael Quigley 2022-10-21 09:31:12 -04:00
parent d479ff8609
commit 014da707d7
No known key found for this signature in database
GPG Key ID: 9B60314A9DD20A62
7 changed files with 80 additions and 65 deletions

View File

@ -13,24 +13,20 @@ type Account struct {
} }
func (self *Store) CreateAccount(a *Account, tx *sqlx.Tx) (int, error) { func (self *Store) CreateAccount(a *Account, tx *sqlx.Tx) (int, error) {
stmt, err := tx.Prepare("insert into accounts (email, password, token) values (?, ?, ?)") stmt, err := tx.Prepare("insert into accounts (email, password, token) values ($1, $2, $3) returning id")
if err != nil { if err != nil {
return 0, errors.Wrap(err, "error preparing accounts insert statement") return 0, errors.Wrap(err, "error preparing accounts insert statement")
} }
res, err := stmt.Exec(a.Email, a.Password, a.Token) var id int
if err != nil { if err := stmt.QueryRow(a.Email, a.Password, a.Token).Scan(&id); err != nil {
return 0, errors.Wrap(err, "error executing accounts insert statement") return 0, errors.Wrap(err, "error executing accounts insert statement")
} }
id, err := res.LastInsertId() return id, nil
if err != nil {
return 0, errors.Wrap(err, "error retrieving last accounts insert id")
}
return int(id), nil
} }
func (self *Store) GetAccount(id int, tx *sqlx.Tx) (*Account, error) { func (self *Store) GetAccount(id int, tx *sqlx.Tx) (*Account, error) {
a := &Account{} a := &Account{}
if err := tx.QueryRowx("select * from accounts where id = ?", id).StructScan(a); err != nil { if err := tx.QueryRowx("select * from accounts where id = $1", id).StructScan(a); err != nil {
return nil, errors.Wrap(err, "error selecting account by id") return nil, errors.Wrap(err, "error selecting account by id")
} }
return a, nil return a, nil
@ -38,7 +34,7 @@ func (self *Store) GetAccount(id int, tx *sqlx.Tx) (*Account, error) {
func (self *Store) FindAccountWithEmail(email string, tx *sqlx.Tx) (*Account, error) { func (self *Store) FindAccountWithEmail(email string, tx *sqlx.Tx) (*Account, error) {
a := &Account{} a := &Account{}
if err := tx.QueryRowx("select * from accounts where email = ?", email).StructScan(a); err != nil { if err := tx.QueryRowx("select * from accounts where email = $1", email).StructScan(a); err != nil {
return nil, errors.Wrap(err, "error selecting account by email") return nil, errors.Wrap(err, "error selecting account by email")
} }
return a, nil return a, nil
@ -46,7 +42,7 @@ func (self *Store) FindAccountWithEmail(email string, tx *sqlx.Tx) (*Account, er
func (self *Store) FindAccountWithToken(token string, tx *sqlx.Tx) (*Account, error) { func (self *Store) FindAccountWithToken(token string, tx *sqlx.Tx) (*Account, error) {
a := &Account{} a := &Account{}
if err := tx.QueryRowx("select * from accounts where token = ?", token).StructScan(a); err != nil { if err := tx.QueryRowx("select * from accounts where token = $1", token).StructScan(a); err != nil {
return nil, errors.Wrap(err, "error selecting account by token") return nil, errors.Wrap(err, "error selecting account by token")
} }
return a, nil return a, nil

View File

@ -13,24 +13,20 @@ type AccountRequest struct {
} }
func (self *Store) CreateAccountRequest(ar *AccountRequest, tx *sqlx.Tx) (int, error) { func (self *Store) CreateAccountRequest(ar *AccountRequest, tx *sqlx.Tx) (int, error) {
stmt, err := tx.Prepare("insert into account_requests (token, email, source_address) values (?, ?, ?)") stmt, err := tx.Prepare("insert into account_requests (token, email, source_address) values ($1, $2, $3) returning id")
if err != nil { if err != nil {
return 0, errors.Wrap(err, "error preparing account_requests insert statement") return 0, errors.Wrap(err, "error preparing account_requests insert statement")
} }
res, err := stmt.Exec(ar.Token, ar.Email, ar.SourceAddress) var id int
if err != nil { if err := stmt.QueryRow(ar.Token, ar.Email, ar.SourceAddress).Scan(&id); err != nil {
return 0, errors.Wrap(err, "error executing account_requests insert statement") return 0, errors.Wrap(err, "error executing account_requests insert statement")
} }
id, err := res.LastInsertId() return id, nil
if err != nil {
return 0, errors.Wrap(err, "error retrieving last account_requests insert id")
}
return int(id), nil
} }
func (self *Store) GetAccountRequest(id int, tx *sqlx.Tx) (*AccountRequest, error) { func (self *Store) GetAccountRequest(id int, tx *sqlx.Tx) (*AccountRequest, error) {
ar := &AccountRequest{} ar := &AccountRequest{}
if err := tx.QueryRowx("select * from account_requests where id = ?", id).StructScan(ar); err != nil { if err := tx.QueryRowx("select * from account_requests where id = $1", id).StructScan(ar); err != nil {
return nil, errors.Wrap(err, "error selecting account_request by id") return nil, errors.Wrap(err, "error selecting account_request by id")
} }
return ar, nil return ar, nil
@ -38,7 +34,7 @@ func (self *Store) GetAccountRequest(id int, tx *sqlx.Tx) (*AccountRequest, erro
func (self *Store) FindAccountRequestWithToken(token string, tx *sqlx.Tx) (*AccountRequest, error) { func (self *Store) FindAccountRequestWithToken(token string, tx *sqlx.Tx) (*AccountRequest, error) {
ar := &AccountRequest{} ar := &AccountRequest{}
if err := tx.QueryRowx("select * from account_requests where token = ?", token).StructScan(ar); err != nil { if err := tx.QueryRowx("select * from account_requests where token = $1", token).StructScan(ar); err != nil {
return nil, errors.Wrap(err, "error selecting account_request by token") return nil, errors.Wrap(err, "error selecting account_request by token")
} }
return ar, nil return ar, nil
@ -46,14 +42,14 @@ func (self *Store) FindAccountRequestWithToken(token string, tx *sqlx.Tx) (*Acco
func (self *Store) FindAccountRequestWithEmail(email string, tx *sqlx.Tx) (*AccountRequest, error) { func (self *Store) FindAccountRequestWithEmail(email string, tx *sqlx.Tx) (*AccountRequest, error) {
ar := &AccountRequest{} ar := &AccountRequest{}
if err := tx.QueryRowx("select * from account_requests where email = ?", email).StructScan(ar); err != nil { if err := tx.QueryRowx("select * from account_requests where email = $1", email).StructScan(ar); err != nil {
return nil, errors.Wrap(err, "error selecting account_request by email") return nil, errors.Wrap(err, "error selecting account_request by email")
} }
return ar, nil return ar, nil
} }
func (self *Store) DeleteAccountRequest(id int, tx *sqlx.Tx) error { func (self *Store) DeleteAccountRequest(id int, tx *sqlx.Tx) error {
stmt, err := tx.Prepare("delete from account_requests where id = ?") stmt, err := tx.Prepare("delete from account_requests where id = $1")
if err != nil { if err != nil {
return errors.Wrap(err, "error preparing account_requests delete statement") return errors.Wrap(err, "error preparing account_requests delete statement")
} }

View File

@ -15,31 +15,27 @@ type Environment struct {
} }
func (self *Store) CreateEnvironment(accountId int, i *Environment, tx *sqlx.Tx) (int, error) { func (self *Store) CreateEnvironment(accountId int, i *Environment, tx *sqlx.Tx) (int, error) {
stmt, err := tx.Prepare("insert into environments (account_id, description, host, address, z_id) values (?, ?, ?, ?, ?)") stmt, err := tx.Prepare("insert into environments (account_id, description, host, address, z_id) values ($1, $2, $3, $4, $5) returning id")
if err != nil { if err != nil {
return 0, errors.Wrap(err, "error preparing environments insert statement") return 0, errors.Wrap(err, "error preparing environments insert statement")
} }
res, err := stmt.Exec(accountId, i.Description, i.Host, i.Address, i.ZId) var id int
if err != nil { if err := stmt.QueryRow(accountId, i.Description, i.Host, i.Address, i.ZId).Scan(&id); err != nil {
return 0, errors.Wrap(err, "error executing environments insert statement") return 0, errors.Wrap(err, "error executing environments insert statement")
} }
id, err := res.LastInsertId() return id, nil
if err != nil {
return 0, errors.Wrap(err, "error retrieving last environments insert id")
}
return int(id), nil
} }
func (self *Store) GetEnvironment(id int, tx *sqlx.Tx) (*Environment, error) { func (self *Store) GetEnvironment(id int, tx *sqlx.Tx) (*Environment, error) {
i := &Environment{} i := &Environment{}
if err := tx.QueryRowx("select * from environments where id = ?", id).StructScan(i); err != nil { if err := tx.QueryRowx("select * from environments where id = $1", id).StructScan(i); err != nil {
return nil, errors.Wrap(err, "error selecting environment by id") return nil, errors.Wrap(err, "error selecting environment by id")
} }
return i, nil return i, nil
} }
func (self *Store) FindEnvironmentsForAccount(accountId int, tx *sqlx.Tx) ([]*Environment, error) { func (self *Store) FindEnvironmentsForAccount(accountId int, tx *sqlx.Tx) ([]*Environment, error) {
rows, err := tx.Queryx("select environments.* from environments where account_id = ?", accountId) rows, err := tx.Queryx("select environments.* from environments where account_id = $1", accountId)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error selecting environments by account id") return nil, errors.Wrap(err, "error selecting environments by account id")
} }
@ -55,7 +51,7 @@ func (self *Store) FindEnvironmentsForAccount(accountId int, tx *sqlx.Tx) ([]*En
} }
func (self *Store) DeleteEnvironment(id int, tx *sqlx.Tx) error { func (self *Store) DeleteEnvironment(id int, tx *sqlx.Tx) error {
stmt, err := tx.Prepare("delete from environments where id = ?") stmt, err := tx.Prepare("delete from environments where id = $1")
if err != nil { if err != nil {
return errors.Wrap(err, "error preparing environments delete statement") return errors.Wrap(err, "error preparing environments delete statement")
} }

View File

@ -15,24 +15,20 @@ type Service struct {
} }
func (self *Store) CreateService(envId int, svc *Service, tx *sqlx.Tx) (int, error) { func (self *Store) CreateService(envId int, svc *Service, tx *sqlx.Tx) (int, error) {
stmt, err := tx.Prepare("insert into services (environment_id, z_id, name, frontend, backend) values (?, ?, ?, ?, ?)") stmt, err := tx.Prepare("insert into services (environment_id, z_id, name, frontend, backend) values ($1, $2, $3, $4, $5) returning id")
if err != nil { if err != nil {
return 0, errors.Wrap(err, "error preparing services insert statement") return 0, errors.Wrap(err, "error preparing services insert statement")
} }
res, err := stmt.Exec(envId, svc.ZId, svc.Name, svc.Frontend, svc.Backend) var id int
if err != nil { if err := stmt.QueryRow(envId, svc.ZId, svc.Name, svc.Frontend, svc.Backend).Scan(&id); err != nil {
return 0, errors.Wrap(err, "error executing services insert statement") return 0, errors.Wrap(err, "error executing services insert statement")
} }
id, err := res.LastInsertId() return id, nil
if err != nil {
return 0, errors.Wrap(err, "error retrieving last services insert id")
}
return int(id), nil
} }
func (self *Store) GetService(id int, tx *sqlx.Tx) (*Service, error) { func (self *Store) GetService(id int, tx *sqlx.Tx) (*Service, error) {
svc := &Service{} svc := &Service{}
if err := tx.QueryRowx("select * from services where id = ?", id).StructScan(svc); err != nil { if err := tx.QueryRowx("select * from services where id = $1", id).StructScan(svc); err != nil {
return nil, errors.Wrap(err, "error selecting service by id") return nil, errors.Wrap(err, "error selecting service by id")
} }
return svc, nil return svc, nil
@ -55,7 +51,7 @@ func (self *Store) GetAllServices(tx *sqlx.Tx) ([]*Service, error) {
} }
func (self *Store) FindServicesForEnvironment(envId int, tx *sqlx.Tx) ([]*Service, error) { func (self *Store) FindServicesForEnvironment(envId int, tx *sqlx.Tx) ([]*Service, error) {
rows, err := tx.Queryx("select services.* from services where environment_id = ?", envId) rows, err := tx.Queryx("select services.* from services where environment_id = $1", envId)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error selecting services by environment id") return nil, errors.Wrap(err, "error selecting services by environment id")
} }
@ -71,7 +67,7 @@ func (self *Store) FindServicesForEnvironment(envId int, tx *sqlx.Tx) ([]*Servic
} }
func (self *Store) UpdateService(svc *Service, tx *sqlx.Tx) error { func (self *Store) UpdateService(svc *Service, tx *sqlx.Tx) error {
sql := "update services set z_id = ?, name = ?, frontend = ?, backend = ?, updated_at = strftime('%Y-%m-%d %H:%M:%f', 'now') where id = ?" sql := "update services set z_id = $1, name = $2, frontend = $3, backend = $4, updated_at = strftime('%Y-%m-%d %H:%M:%f', 'now') where id = $5"
stmt, err := tx.Prepare(sql) stmt, err := tx.Prepare(sql)
if err != nil { if err != nil {
return errors.Wrap(err, "error preparing services update statement") return errors.Wrap(err, "error preparing services update statement")
@ -84,7 +80,7 @@ func (self *Store) UpdateService(svc *Service, tx *sqlx.Tx) error {
} }
func (self *Store) DeleteService(id int, tx *sqlx.Tx) error { func (self *Store) DeleteService(id int, tx *sqlx.Tx) error {
stmt, err := tx.Prepare("delete from services where id = ?") stmt, err := tx.Prepare("delete from services where id = $1")
if err != nil { if err != nil {
return errors.Wrap(err, "error preparing services delete statement") return errors.Wrap(err, "error preparing services delete statement")
} }

View File

@ -4,7 +4,7 @@
-- accounts -- accounts
-- --
create table accounts ( create table accounts (
id integer primary key, id serial primary key,
email varchar(1024) not null unique, email varchar(1024) not null unique,
password char(128) not null, password char(128) not null,
token varchar(32) not null unique, token varchar(32) not null unique,
@ -20,7 +20,7 @@ create table accounts (
-- account_requests -- account_requests
-- --
create table account_requests ( create table account_requests (
id integer primary key, id serial primary key,
token varchar(32) not null unique, token varchar(32) not null unique,
email varchar(1024) not null unique, email varchar(1024) not null unique,
source_address varchar(64) not null, source_address varchar(64) not null,
@ -32,7 +32,7 @@ create table account_requests (
-- environments -- environments
-- --
create table environments ( create table environments (
id integer primary key, id serial primary key,
account_id integer constraint fk_accounts_identities references accounts on delete cascade, account_id integer constraint fk_accounts_identities references accounts on delete cascade,
description text, description text,
host varchar(256), host varchar(256),
@ -48,7 +48,7 @@ create table environments (
-- services -- services
-- --
create table services ( create table services (
id integer primary key, id serial primary key,
environment_id integer constraint fk_environments_services references environments on delete cascade, environment_id integer constraint fk_environments_services references environments on delete cascade,
z_id varchar(32) not null unique, z_id varchar(32) not null unique,
name varchar(32) not null unique, name varchar(32) not null unique,

View File

@ -4,7 +4,9 @@ import (
"fmt" "fmt"
"github.com/iancoleman/strcase" "github.com/iancoleman/strcase"
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
postgresql_schema "github.com/openziti-test-kitchen/zrok/controller/store/sql/postgresql"
sqlite3_schema "github.com/openziti-test-kitchen/zrok/controller/store/sql/sqlite3" sqlite3_schema "github.com/openziti-test-kitchen/zrok/controller/store/sql/sqlite3"
"github.com/pkg/errors" "github.com/pkg/errors"
migrate "github.com/rubenv/sql-migrate" migrate "github.com/rubenv/sql-migrate"
@ -20,6 +22,7 @@ type Model struct {
type Config struct { type Config struct {
Path string Path string
Type string
} }
type Store struct { type Store struct {
@ -28,15 +31,27 @@ type Store struct {
} }
func Open(cfg *Config) (*Store, error) { func Open(cfg *Config) (*Store, error) {
dbx, err := sqlx.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", cfg.Path)) var dbx *sqlx.DB
if err != nil { var err error
return nil, errors.Wrapf(err, "error opening database '%v'", cfg.Path)
switch cfg.Type {
case "sqlite3":
dbx, err = sqlx.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", cfg.Path))
if err != nil {
return nil, errors.Wrapf(err, "error opening database '%v'", cfg.Path)
}
dbx.DB.SetMaxOpenConns(1)
case "postgres":
dbx, err = sqlx.Connect("postgres", cfg.Path)
if err != nil {
return nil, errors.Wrapf(err, "error opening database '%v'", cfg.Path)
}
} }
dbx.DB.SetMaxOpenConns(1)
logrus.Infof("opened database '%v'", cfg.Path) logrus.Infof("opened database '%v'", cfg.Path)
dbx.MapperFunc(strcase.ToSnake) dbx.MapperFunc(strcase.ToSnake)
store := &Store{cfg: cfg, db: dbx} store := &Store{cfg: cfg, db: dbx}
if err := store.migrate(); err != nil { if err := store.migrate(cfg); err != nil {
return nil, errors.Wrapf(err, "error migrating database '%v'", cfg.Path) return nil, errors.Wrapf(err, "error migrating database '%v'", cfg.Path)
} }
return store, nil return store, nil
@ -50,16 +65,31 @@ func (self *Store) Close() error {
return self.db.Close() return self.db.Close()
} }
func (self *Store) migrate() error { func (self *Store) migrate(cfg *Config) error {
migrations := &migrate.EmbedFileSystemMigrationSource{ switch cfg.Type {
FileSystem: sqlite3_schema.FS, case "sqlite3":
Root: "/", migrations := &migrate.EmbedFileSystemMigrationSource{
FileSystem: sqlite3_schema.FS,
Root: "/",
}
migrate.SetTable("migrations")
n, err := migrate.Exec(self.db.DB, "sqlite3", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "error running migrations")
}
logrus.Infof("applied %d migrations", n)
case "postgres":
migrations := &migrate.EmbedFileSystemMigrationSource{
FileSystem: postgresql_schema.FS,
Root: "/",
}
migrate.SetTable("migrations")
n, err := migrate.Exec(self.db.DB, "postgres", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "error running migrations")
}
logrus.Infof("applied %d migrations", n)
} }
migrate.SetTable("migrations")
n, err := migrate.Exec(self.db.DB, "sqlite3", migrations, migrate.Up)
if err != nil {
return errors.Wrap(err, "error running migrations")
}
logrus.Infof("applied %d migrations", n)
return nil return nil
} }

1
go.mod
View File

@ -17,6 +17,7 @@ require (
github.com/jaevor/go-nanoid v1.3.0 github.com/jaevor/go-nanoid v1.3.0
github.com/jessevdk/go-flags v1.5.0 github.com/jessevdk/go-flags v1.5.0
github.com/jmoiron/sqlx v1.3.5 github.com/jmoiron/sqlx v1.3.5
github.com/lib/pq v1.10.0
github.com/mattn/go-sqlite3 v1.14.14 github.com/mattn/go-sqlite3 v1.14.14
github.com/michaelquigley/cf v0.0.12 github.com/michaelquigley/cf v0.0.12
github.com/michaelquigley/pfxlog v0.6.9 github.com/michaelquigley/pfxlog v0.6.9