pr comments

This commit is contained in:
Cam Otts 2023-01-09 13:23:02 -06:00
parent e5b749be71
commit 0734e7b511
No known key found for this signature in database
GPG Key ID: 367B7C7EBD84A8BD
10 changed files with 69 additions and 59 deletions

View File

@ -2,7 +2,6 @@ package main
import ( import (
"fmt" "fmt"
"strconv"
"github.com/jaevor/go-nanoid" "github.com/jaevor/go-nanoid"
"github.com/openziti-test-kitchen/zrok/rest_client_zrok/invite" "github.com/openziti-test-kitchen/zrok/rest_client_zrok/invite"
@ -12,36 +11,32 @@ import (
) )
func init() { func init() {
rootCmd.AddCommand(newGenerateCommand().cmd) adminCmd.AddCommand(newGenerateCommand().cmd)
} }
type generateCommand struct { type generateCommand struct {
cmd *cobra.Command cmd *cobra.Command
amount int
} }
func newGenerateCommand() *generateCommand { func newGenerateCommand() *generateCommand {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "generate <optional-amount>", Use: "generate",
Short: "Generate invite tokens (default: 5)", Short: "Generate invite tokens (default: 5)",
Args: cobra.RangeArgs(0, 1), Args: cobra.ExactArgs(0),
} }
command := &generateCommand{cmd: cmd} command := &generateCommand{cmd: cmd}
cmd.Run = command.run cmd.Run = command.run
cmd.Flags().IntVar(&command.amount, "amount", 5, "Amount of tokens to generate")
return command return command
} }
func (cmd *generateCommand) run(_ *cobra.Command, args []string) { func (cmd *generateCommand) run(_ *cobra.Command, args []string) {
var iterations int64 = 5
if len(args) == 1 {
i, err := strconv.ParseInt(args[0], 10, 64)
if err != nil {
showError("unable to parse amount", err)
}
iterations = i
}
var err error var err error
tokens := make([]string, iterations) tokens := make([]string, cmd.amount)
for i := 0; i < int(iterations); i++ { for i := 0; i < int(cmd.amount); i++ {
tokens[i], err = createToken() tokens[i], err = createToken()
if err != nil { if err != nil {
showError("error creating token", err) showError("error creating token", err)

View File

@ -2,6 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"github.com/openziti-test-kitchen/zrok/rest_client_zrok/account" "github.com/openziti-test-kitchen/zrok/rest_client_zrok/account"
"github.com/openziti-test-kitchen/zrok/rest_model_zrok" "github.com/openziti-test-kitchen/zrok/rest_model_zrok"
"github.com/openziti-test-kitchen/zrok/util" "github.com/openziti-test-kitchen/zrok/util"
@ -15,7 +16,8 @@ func init() {
} }
type inviteCommand struct { type inviteCommand struct {
cmd *cobra.Command cmd *cobra.Command
Token string
} }
func newInviteCommand() *inviteCommand { func newInviteCommand() *inviteCommand {
@ -24,8 +26,12 @@ func newInviteCommand() *inviteCommand {
Short: "Invite a new user to zrok", Short: "Invite a new user to zrok",
Args: cobra.ExactArgs(0), Args: cobra.ExactArgs(0),
} }
command := &inviteCommand{cmd: cmd} command := &inviteCommand{cmd: cmd}
cmd.Run = command.run cmd.Run = command.run
cmd.Flags().StringVar(&command.Token, "token", "", "Invite token required when Zrok running in token store mode")
return command return command
} }
@ -55,6 +61,7 @@ func (cmd *inviteCommand) run(_ *cobra.Command, _ []string) {
req := account.NewInviteParams() req := account.NewInviteParams()
req.Body = &rest_model_zrok.InviteRequest{ req.Body = &rest_model_zrok.InviteRequest{
Email: email, Email: email,
Token: cmd.Token,
} }
_, err = zrok.Account.Invite(req) _, err = zrok.Account.Invite(req)
if err != nil { if err != nil {

View File

@ -39,24 +39,22 @@ func (self *inviteHandler) Handle(params account.InviteParams) middleware.Respon
defer func() { _ = tx.Rollback() }() defer func() { _ = tx.Rollback() }()
if self.cfg.Registration.TokenStrategy == "store" { if self.cfg.Registration.TokenStrategy == "store" {
invite, err := str.GetInvite(tx) invite, err := str.GetInviteByToken(params.Body.Token, tx)
if err != nil { if err != nil {
logrus.Error(err) logrus.Error(err)
return account.NewInviteInternalServerError() return account.NewInviteBadRequest()
} }
invite.Status = store.INVITE_STATUS_TAKEN if err := str.DeleteInvite(invite.Id, tx); err != nil {
if err := str.UpdateInvite(invite, tx); err != nil {
logrus.Error(err)
return account.NewInviteInternalServerError()
}
token = invite.Token
} else {
token, err = createToken()
if err != nil {
logrus.Error(err) logrus.Error(err)
return account.NewInviteInternalServerError() return account.NewInviteInternalServerError()
} }
} }
token, err = createToken()
if err != nil {
logrus.Error(err)
return account.NewInviteInternalServerError()
}
ar := &store.AccountRequest{ ar := &store.AccountRequest{
Token: token, Token: token,
Email: params.Body.Email, Email: params.Body.Email,
@ -120,8 +118,7 @@ func (handler *inviteGenerateHandler) Handle(params invite.InviteGenerateParams)
invites := make([]*store.Invite, len(params.Body.Tokens)) invites := make([]*store.Invite, len(params.Body.Tokens))
for i, token := range params.Body.Tokens { for i, token := range params.Body.Tokens {
invites[i] = &store.Invite{ invites[i] = &store.Invite{
Token: token, Token: token,
Status: store.INVITE_STATUS_UNUSED,
} }
} }
tx, err := str.Begin() tx, err := str.Begin()

View File

@ -10,25 +10,16 @@ import (
type Invite struct { type Invite struct {
Model Model
Token string Token string
Status string `db:"token_status"`
} }
const (
INVITE_STATUS_UNUSED = "UNUSED"
INVITE_STATUS_TAKEN = "TAKEN"
)
func (str *Store) CreateInvites(invites []*Invite, tx *sqlx.Tx) error { func (str *Store) CreateInvites(invites []*Invite, tx *sqlx.Tx) error {
sql := "insert into invites (token, token_status) values %s" sql := "insert into invites (token) values %s"
invs := make([]any, len(invites)*2) invs := make([]any, len(invites))
queries := make([]string, len(invites)) queries := make([]string, len(invites))
ct := 1
for i, inv := range invites { for i, inv := range invites {
invs[i] = inv.Token invs[i] = inv.Token
invs[i+1] = inv.Status queries[i] = fmt.Sprintf("($%d)", i+1)
queries[i] = fmt.Sprintf("($%d, $%d)", ct, ct+1)
ct = ct + 2
} }
stmt, err := tx.Prepare(fmt.Sprintf(sql, strings.Join(queries, ","))) stmt, err := tx.Prepare(fmt.Sprintf(sql, strings.Join(queries, ",")))
if err != nil { if err != nil {
@ -40,22 +31,34 @@ func (str *Store) CreateInvites(invites []*Invite, tx *sqlx.Tx) error {
return nil return nil
} }
func (str *Store) GetInvite(tx *sqlx.Tx) (*Invite, error) { func (str *Store) GetInviteByToken(token string, tx *sqlx.Tx) (*Invite, error) {
invite := &Invite{} invite := &Invite{}
if err := tx.QueryRowx("select * from invites where token_status = $1 limit 1", INVITE_STATUS_UNUSED).StructScan(invite); err != nil { if err := tx.QueryRowx("select * from invites where token = $1", token).StructScan(invite); err != nil {
return nil, errors.Wrap(err, "error getting unused invite") return nil, errors.Wrap(err, "error getting unused invite")
} }
return invite, nil return invite, nil
} }
func (str *Store) UpdateInvite(invite *Invite, tx *sqlx.Tx) error { func (str *Store) UpdateInvite(invite *Invite, tx *sqlx.Tx) error {
stmt, err := tx.Prepare("update invites set token = $1, token_status = $2") stmt, err := tx.Prepare("update invites set token = $1")
if err != nil { if err != nil {
return errors.Wrap(err, "error perparing invites update statement") return errors.Wrap(err, "error perparing invites update statement")
} }
_, err = stmt.Exec(invite.Token, invite.Status) _, err = stmt.Exec(invite.Token)
if err != nil { if err != nil {
return errors.Wrap(err, "error executing invites update statement") return errors.Wrap(err, "error executing invites update statement")
} }
return nil return nil
} }
func (str *Store) DeleteInvite(id int, tx *sqlx.Tx) error {
stmt, err := tx.Prepare("delete from invites where id = $1")
if err != nil {
return errors.Wrap(err, "error preparing invites delete statement")
}
_, err = stmt.Exec(id)
if err != nil {
return errors.Wrap(err, "error executing invites delete statement")
}
return nil
}

View File

@ -5,12 +5,10 @@
--- ---
create table invites ( create table invites (
id serial primary key, id serial primary key,
token varchar(32) not null unique, token varchar(32) not null unique,
token_status varchar(1024) not null unique, created_at timestamptz not null default(current_timestamp),
created_at timestamp not null default(current_timestamp), updated_at timestamptz not null default(current_timestamp),
updated_at timestamp not null default(current_timestamp),
constraint chk_token check(token <> ''), constraint chk_token check(token <> '')
constraint chk_status check(token_status <> '')
); );

View File

@ -5,12 +5,10 @@
--- ---
create table invites ( create table invites (
id serial primary key, id integer primary key,
token varchar(32) not null unique, token string not null unique,
token_status varchar(1024) not null unique, created_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')),
created_at timestamp not null default(current_timestamp), updated_at datetime not null default(strftime('%Y-%m-%d %H:%M:%f', 'now')),
updated_at timestamp not null default(current_timestamp),
constraint chk_token check(token <> ''), constraint chk_token check(token <> ''),
constraint chk_status check(token_status <> '')
); );

View File

@ -19,6 +19,9 @@ type InviteRequest struct {
// email // email
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"`
// token
Token string `json:"token,omitempty"`
} }
// Validate validates this invite request // Validate validates this invite request

View File

@ -903,6 +903,9 @@ func init() {
"properties": { "properties": {
"email": { "email": {
"type": "string" "type": "string"
},
"token": {
"type": "string"
} }
} }
}, },
@ -2060,6 +2063,9 @@ func init() {
"properties": { "properties": {
"email": { "email": {
"type": "string" "type": "string"
},
"token": {
"type": "string"
} }
} }
}, },

View File

@ -578,6 +578,8 @@ definitions:
properties: properties:
email: email:
type: string type: string
token:
type: string
loginRequest: loginRequest:
type: object type: object

View File

@ -96,6 +96,7 @@
* @memberof module:types * @memberof module:types
* *
* @property {string} email * @property {string} email
* @property {string} token
*/ */
/** /**