zrepl wait: poll-based implementation

This commit is contained in:
Christian Schwarz 2021-03-21 19:57:33 +01:00
parent 68b895d0bc
commit 40be626b3a
3 changed files with 146 additions and 48 deletions

110
client/wait.go Normal file
View File

@ -0,0 +1,110 @@
package client
import (
"context"
"fmt"
"strconv"
"time"
"github.com/kr/pretty"
"github.com/pkg/errors"
"github.com/spf13/pflag"
"github.com/zrepl/zrepl/cli"
"github.com/zrepl/zrepl/config"
"github.com/zrepl/zrepl/daemon"
"github.com/zrepl/zrepl/daemon/job"
)
var waitCmdArgs struct {
verbose bool
interval time.Duration
}
var WaitCmd = &cli.Subcommand{
Use: "wait [active JOB INVOCATION_ID WHAT]",
Short: "",
Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
return runWaitCmd(subcommand.Config(), args)
},
SetupFlags: func(f *pflag.FlagSet) {
f.BoolVarP(&waitCmdArgs.verbose, "verbose", "v", false, "verbose output")
f.DurationVarP(&waitCmdArgs.interval, "poll-interval", "i", 100*time.Millisecond, "poll interval")
},
}
func runWaitCmd(config *config.Config, args []string) error {
httpc, err := controlHttpClient(config.Global.Control.SockPath)
if err != nil {
return err
}
if args[0] != "active" {
panic(args)
}
args = args[1:]
jobName := args[0]
invocationId, err := strconv.ParseUint(args[1], 10, 64)
if err != nil {
return errors.Wrap(err, "parse invocation id")
}
waitWhat := args[2]
doneErr := fmt.Errorf("done")
var pollRequest job.ActiveSidePollRequest
// updated by subsequent requests
pollRequest = job.ActiveSidePollRequest{
InvocationId: invocationId,
What: waitWhat,
}
pollOnce := func() error {
var res job.ActiveSidePollResponse
if waitCmdArgs.verbose {
pretty.Println("making poll request", pollRequest)
}
err = jsonRequestResponse(httpc, daemon.ControlJobEndpointPollActive,
struct {
Job string
job.ActiveSidePollRequest
}{
Job: jobName,
ActiveSidePollRequest: pollRequest,
},
&res,
)
if err != nil {
return err
}
if waitCmdArgs.verbose {
pretty.Println("got poll response", res)
}
if res.Done {
return doneErr
}
pollRequest.InvocationId = res.InvocationId
return nil
}
t := time.NewTicker(waitCmdArgs.interval)
for range t.C {
err := pollOnce()
if err == doneErr {
return nil
} else if err != nil {
return err
}
}
return err
}

View File

@ -77,7 +77,7 @@ const (
ControlJobEndpointVersion string = "/version" ControlJobEndpointVersion string = "/version"
ControlJobEndpointStatus string = "/status" ControlJobEndpointStatus string = "/status"
ControlJobEndpointSignal string = "/signal" ControlJobEndpointSignal string = "/signal"
ControlJobEndpointWaitActive string = "/wait/active" ControlJobEndpointPollActive string = "/poll/active"
) )
func (j *controlJob) Run(ctx context.Context) { func (j *controlJob) Run(ctx context.Context) {
@ -131,11 +131,10 @@ func (j *controlJob) Run(ctx context.Context) {
return s, nil return s, nil
}}) }})
mux.Handle(ControlJobEndpointWaitActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) { mux.Handle(ControlJobEndpointPollActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) {
type reqT struct { type reqT struct {
Job string Job string
InvocationId uint64 job.ActiveSidePollRequest
What string
} }
var req reqT var req reqT
if decoder(&req) != nil { if decoder(&req) != nil {
@ -157,24 +156,11 @@ func (j *controlJob) Run(ctx context.Context) {
return v, err return v, err
} }
cbCalled := make(chan struct{}) res, err := ajo.Poll(req.ActiveSidePollRequest)
err = ajo.AddActiveSideWaiter(req.InvocationId, req.What, func() {
log.WithField("request", req).Debug("active side waiter done")
close(cbCalled)
})
j.jobs.m.RUnlock() // unlock before waiting! j.jobs.m.RUnlock()
if err != nil {
return struct{}{}, err
}
select {
// TODO ctx with timeout!
case <-cbCalled:
return struct{}{}, nil
}
return res, err
}}}) }}})
mux.Handle(ControlJobEndpointSignal, mux.Handle(ControlJobEndpointSignal,

View File

@ -45,8 +45,8 @@ type ActiveSide struct {
tasksMtx sync.Mutex tasksMtx sync.Mutex
tasks activeSideTasks tasks activeSideTasks
nextInvocationId uint64
activeInvocationId uint64 // 0 <=> inactive activeInvocationId uint64 // 0 <=> inactive
doneWaiters, nextWaiters []func()
} }
//go:generate enumer -type=ActiveSideState //go:generate enumer -type=ActiveSideState
@ -434,7 +434,7 @@ func (j *ActiveSide) Run(ctx context.Context) {
defer endTask() defer endTask()
go j.mode.RunPeriodic(periodicCtx, periodicDone) go j.mode.RunPeriodic(periodicCtx, periodicDone)
invocationCount := 0 j.nextInvocationId = 1
outer: outer:
for { for {
@ -448,58 +448,60 @@ outer:
j.mode.ResetConnectBackoff() j.mode.ResetConnectBackoff()
case <-periodicDone: case <-periodicDone:
} }
invocationCount++
j.tasksMtx.Lock() j.tasksMtx.Lock()
j.activeInvocationId = uint64(invocationCount) j.activeInvocationId = j.nextInvocationId
j.doneWaiters = j.nextWaiters j.nextInvocationId++
j.nextWaiters = nil
j.tasksMtx.Unlock() j.tasksMtx.Unlock()
invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", invocationCount)) invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
j.do(invocationCtx) j.do(invocationCtx)
j.tasksMtx.Lock() j.tasksMtx.Lock()
j.activeInvocationId = 0 j.activeInvocationId = 0
for _, f := range j.doneWaiters {
go f()
}
j.doneWaiters = nil
j.tasksMtx.Unlock() j.tasksMtx.Unlock()
endSpan() endSpan()
} }
} }
type AddActiveSideWaiterRequest struct { type ActiveSidePollRequest struct {
InvocationId uint64 InvocationId uint64
What string What string
} }
func (j *ActiveSide) AddActiveSideWaiter(invocationId uint64, what string, cb func()) error { type ActiveSidePollResponse struct {
Done bool
InvocationId uint64
}
func (j *ActiveSide) Poll(req ActiveSidePollRequest) (*ActiveSidePollResponse, error) {
j.tasksMtx.Lock() j.tasksMtx.Lock()
defer j.tasksMtx.Unlock() defer j.tasksMtx.Unlock()
var targetQueue *[]func() waitForId := req.InvocationId
if invocationId == 0 { if req.InvocationId == 0 {
// handle the case where the client doesn't know what the current invocation id is
if j.activeInvocationId != 0 { if j.activeInvocationId != 0 {
targetQueue = &j.doneWaiters waitForId = j.activeInvocationId
} else { } else {
targetQueue = &j.nextWaiters waitForId = j.nextInvocationId
} }
} else if j.activeInvocationId == invocationId {
targetQueue = &j.nextWaiters
} else {
return fmt.Errorf("invocation %d is not the current invocation, current invocation is %d (0 means no active invocation); pass id '0' to wait for the next invocation", invocationId, j.activeInvocationId)
} }
switch what { switch req.What {
case "invocation-done": case "invocation-done":
*targetQueue = append(*targetQueue, cb) var done bool
default: if j.activeInvocationId == 0 {
return fmt.Errorf("unknown wait target %q", what) done = waitForId < j.nextInvocationId
} else {
done = waitForId < j.activeInvocationId
}
res := &ActiveSidePollResponse{Done: done, InvocationId: waitForId}
return res, nil
default:
return nil, fmt.Errorf("unknown wait target %q", req.What)
} }
return nil
} }
func (j *ActiveSide) do(ctx context.Context) { func (j *ActiveSide) do(ctx context.Context) {