mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-20 13:18:47 +01:00
zrepl wait: poll-based implementation
This commit is contained in:
parent
68b895d0bc
commit
40be626b3a
110
client/wait.go
Normal file
110
client/wait.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/kr/pretty"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
|
||||||
|
"github.com/zrepl/zrepl/cli"
|
||||||
|
"github.com/zrepl/zrepl/config"
|
||||||
|
"github.com/zrepl/zrepl/daemon"
|
||||||
|
"github.com/zrepl/zrepl/daemon/job"
|
||||||
|
)
|
||||||
|
|
||||||
|
var waitCmdArgs struct {
|
||||||
|
verbose bool
|
||||||
|
interval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var WaitCmd = &cli.Subcommand{
|
||||||
|
Use: "wait [active JOB INVOCATION_ID WHAT]",
|
||||||
|
Short: "",
|
||||||
|
Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
|
||||||
|
return runWaitCmd(subcommand.Config(), args)
|
||||||
|
},
|
||||||
|
SetupFlags: func(f *pflag.FlagSet) {
|
||||||
|
f.BoolVarP(&waitCmdArgs.verbose, "verbose", "v", false, "verbose output")
|
||||||
|
f.DurationVarP(&waitCmdArgs.interval, "poll-interval", "i", 100*time.Millisecond, "poll interval")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWaitCmd(config *config.Config, args []string) error {
|
||||||
|
|
||||||
|
httpc, err := controlHttpClient(config.Global.Control.SockPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if args[0] != "active" {
|
||||||
|
panic(args)
|
||||||
|
}
|
||||||
|
args = args[1:]
|
||||||
|
|
||||||
|
jobName := args[0]
|
||||||
|
|
||||||
|
invocationId, err := strconv.ParseUint(args[1], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "parse invocation id")
|
||||||
|
}
|
||||||
|
|
||||||
|
waitWhat := args[2]
|
||||||
|
|
||||||
|
doneErr := fmt.Errorf("done")
|
||||||
|
|
||||||
|
var pollRequest job.ActiveSidePollRequest
|
||||||
|
|
||||||
|
// updated by subsequent requests
|
||||||
|
pollRequest = job.ActiveSidePollRequest{
|
||||||
|
InvocationId: invocationId,
|
||||||
|
What: waitWhat,
|
||||||
|
}
|
||||||
|
|
||||||
|
pollOnce := func() error {
|
||||||
|
var res job.ActiveSidePollResponse
|
||||||
|
if waitCmdArgs.verbose {
|
||||||
|
pretty.Println("making poll request", pollRequest)
|
||||||
|
}
|
||||||
|
err = jsonRequestResponse(httpc, daemon.ControlJobEndpointPollActive,
|
||||||
|
struct {
|
||||||
|
Job string
|
||||||
|
job.ActiveSidePollRequest
|
||||||
|
}{
|
||||||
|
Job: jobName,
|
||||||
|
ActiveSidePollRequest: pollRequest,
|
||||||
|
},
|
||||||
|
&res,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if waitCmdArgs.verbose {
|
||||||
|
pretty.Println("got poll response", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Done {
|
||||||
|
return doneErr
|
||||||
|
}
|
||||||
|
|
||||||
|
pollRequest.InvocationId = res.InvocationId
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t := time.NewTicker(waitCmdArgs.interval)
|
||||||
|
for range t.C {
|
||||||
|
err := pollOnce()
|
||||||
|
if err == doneErr {
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
@ -77,7 +77,7 @@ const (
|
|||||||
ControlJobEndpointVersion string = "/version"
|
ControlJobEndpointVersion string = "/version"
|
||||||
ControlJobEndpointStatus string = "/status"
|
ControlJobEndpointStatus string = "/status"
|
||||||
ControlJobEndpointSignal string = "/signal"
|
ControlJobEndpointSignal string = "/signal"
|
||||||
ControlJobEndpointWaitActive string = "/wait/active"
|
ControlJobEndpointPollActive string = "/poll/active"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (j *controlJob) Run(ctx context.Context) {
|
func (j *controlJob) Run(ctx context.Context) {
|
||||||
@ -131,11 +131,10 @@ func (j *controlJob) Run(ctx context.Context) {
|
|||||||
return s, nil
|
return s, nil
|
||||||
}})
|
}})
|
||||||
|
|
||||||
mux.Handle(ControlJobEndpointWaitActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) {
|
mux.Handle(ControlJobEndpointPollActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) {
|
||||||
type reqT struct {
|
type reqT struct {
|
||||||
Job string
|
Job string
|
||||||
InvocationId uint64
|
job.ActiveSidePollRequest
|
||||||
What string
|
|
||||||
}
|
}
|
||||||
var req reqT
|
var req reqT
|
||||||
if decoder(&req) != nil {
|
if decoder(&req) != nil {
|
||||||
@ -157,24 +156,11 @@ func (j *controlJob) Run(ctx context.Context) {
|
|||||||
return v, err
|
return v, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cbCalled := make(chan struct{})
|
res, err := ajo.Poll(req.ActiveSidePollRequest)
|
||||||
err = ajo.AddActiveSideWaiter(req.InvocationId, req.What, func() {
|
|
||||||
log.WithField("request", req).Debug("active side waiter done")
|
|
||||||
close(cbCalled)
|
|
||||||
})
|
|
||||||
|
|
||||||
j.jobs.m.RUnlock() // unlock before waiting!
|
j.jobs.m.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return struct{}{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
// TODO ctx with timeout!
|
|
||||||
case <-cbCalled:
|
|
||||||
return struct{}{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
|
return res, err
|
||||||
}}})
|
}}})
|
||||||
|
|
||||||
mux.Handle(ControlJobEndpointSignal,
|
mux.Handle(ControlJobEndpointSignal,
|
||||||
|
@ -45,8 +45,8 @@ type ActiveSide struct {
|
|||||||
|
|
||||||
tasksMtx sync.Mutex
|
tasksMtx sync.Mutex
|
||||||
tasks activeSideTasks
|
tasks activeSideTasks
|
||||||
|
nextInvocationId uint64
|
||||||
activeInvocationId uint64 // 0 <=> inactive
|
activeInvocationId uint64 // 0 <=> inactive
|
||||||
doneWaiters, nextWaiters []func()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:generate enumer -type=ActiveSideState
|
//go:generate enumer -type=ActiveSideState
|
||||||
@ -434,7 +434,7 @@ func (j *ActiveSide) Run(ctx context.Context) {
|
|||||||
defer endTask()
|
defer endTask()
|
||||||
go j.mode.RunPeriodic(periodicCtx, periodicDone)
|
go j.mode.RunPeriodic(periodicCtx, periodicDone)
|
||||||
|
|
||||||
invocationCount := 0
|
j.nextInvocationId = 1
|
||||||
|
|
||||||
outer:
|
outer:
|
||||||
for {
|
for {
|
||||||
@ -448,58 +448,60 @@ outer:
|
|||||||
j.mode.ResetConnectBackoff()
|
j.mode.ResetConnectBackoff()
|
||||||
case <-periodicDone:
|
case <-periodicDone:
|
||||||
}
|
}
|
||||||
invocationCount++
|
|
||||||
|
|
||||||
j.tasksMtx.Lock()
|
j.tasksMtx.Lock()
|
||||||
j.activeInvocationId = uint64(invocationCount)
|
j.activeInvocationId = j.nextInvocationId
|
||||||
j.doneWaiters = j.nextWaiters
|
j.nextInvocationId++
|
||||||
j.nextWaiters = nil
|
|
||||||
j.tasksMtx.Unlock()
|
j.tasksMtx.Unlock()
|
||||||
|
|
||||||
invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", invocationCount))
|
invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
|
||||||
j.do(invocationCtx)
|
j.do(invocationCtx)
|
||||||
|
|
||||||
j.tasksMtx.Lock()
|
j.tasksMtx.Lock()
|
||||||
j.activeInvocationId = 0
|
j.activeInvocationId = 0
|
||||||
for _, f := range j.doneWaiters {
|
|
||||||
go f()
|
|
||||||
}
|
|
||||||
j.doneWaiters = nil
|
|
||||||
j.tasksMtx.Unlock()
|
j.tasksMtx.Unlock()
|
||||||
|
|
||||||
endSpan()
|
endSpan()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type AddActiveSideWaiterRequest struct {
|
type ActiveSidePollRequest struct {
|
||||||
InvocationId uint64
|
InvocationId uint64
|
||||||
What string
|
What string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *ActiveSide) AddActiveSideWaiter(invocationId uint64, what string, cb func()) error {
|
type ActiveSidePollResponse struct {
|
||||||
|
Done bool
|
||||||
|
InvocationId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *ActiveSide) Poll(req ActiveSidePollRequest) (*ActiveSidePollResponse, error) {
|
||||||
j.tasksMtx.Lock()
|
j.tasksMtx.Lock()
|
||||||
defer j.tasksMtx.Unlock()
|
defer j.tasksMtx.Unlock()
|
||||||
|
|
||||||
var targetQueue *[]func()
|
waitForId := req.InvocationId
|
||||||
if invocationId == 0 {
|
if req.InvocationId == 0 {
|
||||||
|
// handle the case where the client doesn't know what the current invocation id is
|
||||||
if j.activeInvocationId != 0 {
|
if j.activeInvocationId != 0 {
|
||||||
targetQueue = &j.doneWaiters
|
waitForId = j.activeInvocationId
|
||||||
} else {
|
} else {
|
||||||
targetQueue = &j.nextWaiters
|
waitForId = j.nextInvocationId
|
||||||
}
|
}
|
||||||
} else if j.activeInvocationId == invocationId {
|
|
||||||
targetQueue = &j.nextWaiters
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("invocation %d is not the current invocation, current invocation is %d (0 means no active invocation); pass id '0' to wait for the next invocation", invocationId, j.activeInvocationId)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch what {
|
switch req.What {
|
||||||
case "invocation-done":
|
case "invocation-done":
|
||||||
*targetQueue = append(*targetQueue, cb)
|
var done bool
|
||||||
default:
|
if j.activeInvocationId == 0 {
|
||||||
return fmt.Errorf("unknown wait target %q", what)
|
done = waitForId < j.nextInvocationId
|
||||||
|
} else {
|
||||||
|
done = waitForId < j.activeInvocationId
|
||||||
|
}
|
||||||
|
res := &ActiveSidePollResponse{Done: done, InvocationId: waitForId}
|
||||||
|
return res, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown wait target %q", req.What)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *ActiveSide) do(ctx context.Context) {
|
func (j *ActiveSide) do(ctx context.Context) {
|
||||||
|
Loading…
Reference in New Issue
Block a user