From 40be626b3ab9883d1489e19d55aa1ce2e329877e Mon Sep 17 00:00:00 2001
From: Christian Schwarz
Date: Sun, 21 Mar 2021 19:57:33 +0100
Subject: [PATCH] zrepl wait: poll-based implementation

Replace the blocking /wait/active control endpoint and its callback
waiter queues with a /poll/active endpoint that answers immediately.
The new `zrepl wait` client repeats the poll request at a configurable
interval until the daemon reports the invocation as done.

---
 client/wait.go       | 114 +++++++++++++++++++++++++++++++++++++++++++
 daemon/control.go    |  26 +++-------
 daemon/job/active.go |  73 ++++++++++++++++-----------
 3 files changed, 165 insertions(+), 48 deletions(-)
 create mode 100644 client/wait.go

diff --git a/client/wait.go b/client/wait.go
new file mode 100644
index 0000000..f11331b
--- /dev/null
+++ b/client/wait.go
@@ -0,0 +1,114 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/kr/pretty"
+	"github.com/pkg/errors"
+	"github.com/spf13/pflag"
+
+	"github.com/zrepl/zrepl/cli"
+	"github.com/zrepl/zrepl/config"
+	"github.com/zrepl/zrepl/daemon"
+	"github.com/zrepl/zrepl/daemon/job"
+)
+
+// waitCmdArgs holds the flag values of the wait subcommand.
+var waitCmdArgs struct {
+	verbose  bool
+	interval time.Duration
+}
+
+// WaitCmd polls the daemon's control socket until the daemon reports the
+// requested condition of an active-side job invocation as done.
+var WaitCmd = &cli.Subcommand{
+	Use:   "wait [active JOB INVOCATION_ID WHAT]",
+	Short: "",
+	Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error {
+		return runWaitCmd(subcommand.Config(), args)
+	},
+	SetupFlags: func(f *pflag.FlagSet) {
+		f.BoolVarP(&waitCmdArgs.verbose, "verbose", "v", false, "verbose output")
+		f.DurationVarP(&waitCmdArgs.interval, "poll-interval", "i", 100*time.Millisecond, "poll interval")
+	},
+}
+
+// runWaitCmd implements "wait active JOB INVOCATION_ID WHAT".
+//
+// It polls the daemon once immediately and then once per poll interval.
+// It returns nil as soon as a poll response has Done set, and the first
+// error encountered otherwise. Malformed arguments yield an error.
+func runWaitCmd(config *config.Config, args []string) error {
+	if len(args) < 4 {
+		return fmt.Errorf("expected arguments \"active JOB INVOCATION_ID WHAT\", got %d argument(s)", len(args))
+	}
+	if args[0] != "active" {
+		return fmt.Errorf("unknown wait kind %q, expected \"active\"", args[0])
+	}
+	jobName, waitWhat := args[1], args[3]
+
+	invocationId, err := strconv.ParseUint(args[2], 10, 64)
+	if err != nil {
+		return errors.Wrap(err, "parse invocation id")
+	}
+
+	// time.NewTicker panics on non-positive durations, reject them here.
+	if waitCmdArgs.interval <= 0 {
+		return fmt.Errorf("poll interval must be positive, got %s", waitCmdArgs.interval)
+	}
+
+	httpc, err := controlHttpClient(config.Global.Control.SockPath)
+	if err != nil {
+		return err
+	}
+
+	// The invocation id is overwritten with the id from each response, so
+	// that follow-up requests refer to the invocation the daemon picked.
+	pollRequest := job.ActiveSidePollRequest{
+		InvocationId: invocationId,
+		What:         waitWhat,
+	}
+
+	// pollOnce sends one poll request and reports whether the daemon
+	// answered with Done.
+	pollOnce := func() (bool, error) {
+		if waitCmdArgs.verbose {
+			pretty.Println("making poll request", pollRequest)
+		}
+		var res job.ActiveSidePollResponse
+		err := jsonRequestResponse(httpc, daemon.ControlJobEndpointPollActive,
+			struct {
+				Job string
+				job.ActiveSidePollRequest
+			}{
+				Job:                   jobName,
+				ActiveSidePollRequest: pollRequest,
+			},
+			&res,
+		)
+		if err != nil {
+			return false, err
+		}
+		if waitCmdArgs.verbose {
+			pretty.Println("got poll response", res)
+		}
+		pollRequest.InvocationId = res.InvocationId
+		return res.Done, nil
+	}
+
+	t := time.NewTicker(waitCmdArgs.interval)
+	defer t.Stop()
+	for {
+		done, err := pollOnce()
+		if err != nil {
+			return err
+		}
+		if done {
+			return nil
+		}
+		<-t.C
+	}
+}
diff --git a/daemon/control.go b/daemon/control.go
index 40c60f2..9696302 100644
--- a/daemon/control.go
+++ b/daemon/control.go
@@ -77,7 +77,7 @@ const (
 	ControlJobEndpointVersion    string = "/version"
 	ControlJobEndpointStatus     string = "/status"
 	ControlJobEndpointSignal     string = "/signal"
-	ControlJobEndpointWaitActive string = "/wait/active"
+	ControlJobEndpointPollActive string = "/poll/active"
 )
 
 func (j *controlJob) Run(ctx context.Context) {
@@ -131,11 +131,10 @@ func (j *controlJob) Run(ctx context.Context) {
 			return s, nil
 		}})
 
-	mux.Handle(ControlJobEndpointWaitActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) {
+	mux.Handle(ControlJobEndpointPollActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) {
 		type reqT struct {
 			Job string
-			InvocationId uint64
-			What string
+			job.ActiveSidePollRequest
 		}
 		var req reqT
 		if decoder(&req) != nil {
@@ -157,24 +156,11 @@ func (j *controlJob) Run(ctx context.Context) {
 			return v, err
 		}
 
-		cbCalled := make(chan struct{})
-		err = ajo.AddActiveSideWaiter(req.InvocationId, req.What, func() {
-			log.WithField("request", req).Debug("active side waiter done")
-			close(cbCalled)
-		})
+		res, err := ajo.Poll(req.ActiveSidePollRequest)
 
-		j.jobs.m.RUnlock() // unlock before waiting!
-
-		if err != nil {
-			return struct{}{}, err
-		}
-
-		select {
-		// TODO ctx with timeout!
-		case <-cbCalled:
-			return struct{}{}, nil
-		}
+		j.jobs.m.RUnlock()
 
+		return res, err
 	}}})
 
 	mux.Handle(ControlJobEndpointSignal,
diff --git a/daemon/job/active.go b/daemon/job/active.go
index 5f23845..2c2a183 100644
--- a/daemon/job/active.go
+++ b/daemon/job/active.go
@@ -43,10 +43,10 @@ type ActiveSide struct {
 	promBytesReplicated   *prometheus.CounterVec   // labels: filesystem
 	promReplicationErrors prometheus.Gauge
 
-	tasksMtx                 sync.Mutex
-	tasks                    activeSideTasks
-	activeInvocationId       uint64 // 0 <=> inactive
-	doneWaiters, nextWaiters []func()
+	tasksMtx           sync.Mutex
+	tasks              activeSideTasks
+	nextInvocationId   uint64 // id the next invocation gets, set to 1 by Run
+	activeInvocationId uint64 // 0 <=> inactive
 }
 
 //go:generate enumer -type=ActiveSideState
@@ -434,7 +434,9 @@ func (j *ActiveSide) Run(ctx context.Context) {
 	defer endTask()
 	go j.mode.RunPeriodic(periodicCtx, periodicDone)
 
-	invocationCount := 0
+	j.tasksMtx.Lock()
+	j.nextInvocationId = 1
+	j.tasksMtx.Unlock()
 outer:
 	for {
 
@@ -448,58 +450,73 @@ outer:
 			j.mode.ResetConnectBackoff()
 		case <-periodicDone:
 		}
-		invocationCount++
 
 		j.tasksMtx.Lock()
-		j.activeInvocationId = uint64(invocationCount)
-		j.doneWaiters = j.nextWaiters
-		j.nextWaiters = nil
+		invocationId := j.nextInvocationId
+		j.activeInvocationId = invocationId
+		j.nextInvocationId++
 		j.tasksMtx.Unlock()
 
-		invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", invocationCount))
+		invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", invocationId))
 		j.do(invocationCtx)
 
 		j.tasksMtx.Lock()
 		j.activeInvocationId = 0
-		for _, f := range j.doneWaiters {
-			go f()
-		}
-		j.doneWaiters = nil
 		j.tasksMtx.Unlock()
 
 		endSpan()
 	}
 }
 
-type AddActiveSideWaiterRequest struct {
+// ActiveSidePollRequest asks whether a condition (What) has been reached
+// for the invocation InvocationId of an active-side job.
+// InvocationId 0 lets Poll pick the invocation.
+type ActiveSidePollRequest struct {
 	InvocationId uint64
 	What         string
 }
 
-func (j *ActiveSide) AddActiveSideWaiter(invocationId uint64, what string, cb func()) error {
+// ActiveSidePollResponse is the result of Poll. InvocationId is the
+// invocation the request was resolved to.
+type ActiveSidePollResponse struct {
+	Done         bool
+	InvocationId uint64
+}
+
+// Poll reports whether the condition described by req has been reached.
+//
+// If req.InvocationId is 0, the currently active invocation is used or,
+// if no invocation is active, the next one. The only supported req.What
+// is "invocation-done"; any other value results in an error.
+func (j *ActiveSide) Poll(req ActiveSidePollRequest) (*ActiveSidePollResponse, error) {
 	j.tasksMtx.Lock()
 	defer j.tasksMtx.Unlock()
 
-	var targetQueue *[]func()
-	if invocationId == 0 {
+	waitForId := req.InvocationId
+	if req.InvocationId == 0 {
+		// handle the case where the client doesn't know what the current invocation id is
 		if j.activeInvocationId != 0 {
-			targetQueue = &j.doneWaiters
+			waitForId = j.activeInvocationId
 		} else {
-			targetQueue = &j.nextWaiters
+			waitForId = j.nextInvocationId
 		}
-	} else if j.activeInvocationId == invocationId {
-		targetQueue = &j.nextWaiters
-	} else {
-		return fmt.Errorf("invocation %d is not the current invocation, current invocation is %d (0 means no active invocation); pass id '0' to wait for the next invocation", invocationId, j.activeInvocationId)
 	}
 
-	switch what {
+	switch req.What {
 	case "invocation-done":
-		*targetQueue = append(*targetQueue, cb)
+		// ids are handed out in increasing order, so every id below the
+		// active one (or below the next one when idle) has finished
+		var done bool
+		if j.activeInvocationId == 0 {
+			done = waitForId < j.nextInvocationId
+		} else {
+			done = waitForId < j.activeInvocationId
+		}
+		res := &ActiveSidePollResponse{Done: done, InvocationId: waitForId}
+		return res, nil
 	default:
-		return fmt.Errorf("unknown wait target %q", what)
+		return nil, fmt.Errorf("unknown wait target %q", req.What)
 	}
-	return nil
 }
 
 func (j *ActiveSide) do(ctx context.Context) {