diff --git a/client/signal.go b/client/signal.go deleted file mode 100644 index 5a3ec79..0000000 --- a/client/signal.go +++ /dev/null @@ -1,65 +0,0 @@ -package client - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/pkg/errors" - - "github.com/zrepl/zrepl/cli" - "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon" - "github.com/zrepl/zrepl/daemon/job" -) - -var SignalCmd = &cli.Subcommand{ - Use: "signal JOB [replication|reset|snapshot]", - Short: "run a job replication, abort its current invocation, run a snapshot job", - Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error { - return runSignalCmd(subcommand.Config(), args) - }, -} - -func runSignalCmd(config *config.Config, args []string) error { - if len(args) != 2 { - return errors.Errorf("Expected 2 arguments: [replication|reset|snapshot] JOB") - } - - httpc, err := controlHttpClient(config.Global.Control.SockPath) - if err != nil { - return err - } - - jobName := args[0] - what := args[1] - - var res job.ActiveSideSignalResponse - err = jsonRequestResponse(httpc, daemon.ControlJobEndpointSignalActive, - struct { - Job string - job.ActiveSideSignalRequest - }{ - Job: jobName, - ActiveSideSignalRequest: job.ActiveSideSignalRequest{ - What: what, - }, - }, - &res, - ) - - pollRequest := daemon.ControlJobEndpointSignalActiveRequest{ - Job: jobName, - ActiveSidePollRequest: job.ActiveSidePollRequest{ - InvocationId: res.InvocationId, - What: what, - }, - } - - j, err := json.Marshal(pollRequest) - if err != nil { - panic(err) - } - fmt.Println(string(j)) - return err -} diff --git a/client/status/client/client.go b/client/status/client/client.go index d495a18..2096803 100644 --- a/client/status/client/client.go +++ b/client/status/client/client.go @@ -44,13 +44,13 @@ func (c *Client) StatusRaw() ([]byte, error) { } func (c *Client) signal(jobName, sig string) error { - return jsonRequestResponse(c.h, daemon.ControlJobEndpointSignalActive, + return jsonRequestResponse(c.h, daemon.ControlJobEndpointTriggerActive, struct { Job string - job.ActiveSideSignalRequest + job.ActiveSideTriggerRequest }{ Job: jobName, - ActiveSideSignalRequest: job.ActiveSideSignalRequest{ + ActiveSideTriggerRequest: job.ActiveSideTriggerRequest{ What: sig, }, }, diff --git a/client/trigger.go b/client/trigger.go new file mode 100644 index 0000000..a002915 --- /dev/null +++ b/client/trigger.go @@ -0,0 +1,94 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/pkg/errors" + + "github.com/zrepl/zrepl/cli" + "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/daemon" + "github.com/zrepl/zrepl/daemon/job" +) + +var TriggerCmd = &cli.Subcommand{ + Use: "trigger JOB [replication|snapshot]", + Short: "", + Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error { + return runTriggerCmd(subcommand.Config(), args) + }, +} + +type TriggerToken struct { + // TODO version, daemon invocation id, etc. + Job string + InvocationId uint64 +} + +func (t TriggerToken) Encode() string { + j, err := json.Marshal(t) + if err != nil { + panic(err) + } + return string(j) +} + +func (t *TriggerToken) Decode(s string) error { + return json.Unmarshal([]byte(s), t) +} + +func (t TriggerToken) ToReset() daemon.ControlJobEndpointResetActiveRequest { + return daemon.ControlJobEndpointResetActiveRequest{ + Job: t.Job, + ActiveSideResetRequest: job.ActiveSideResetRequest{ + InvocationId: t.InvocationId, + }, + } +} + +func (t TriggerToken) ToWait() daemon.ControlJobEndpointWaitActiveRequest { + return daemon.ControlJobEndpointWaitActiveRequest{ + Job: t.Job, + ActiveSidePollRequest: job.ActiveSidePollRequest{ + InvocationId: t.InvocationId, + }, + } +} + +func runTriggerCmd(config *config.Config, args []string) error { + if len(args) != 2 { + return errors.Errorf("Expected 2 arguments: [replication|reset|snapshot] JOB") + } + + httpc, err := controlHttpClient(config.Global.Control.SockPath) + if err != nil { + return err + } + + jobName := args[0] + what := args[1] + + var res job.ActiveSideSignalResponse + err = jsonRequestResponse(httpc, daemon.ControlJobEndpointTriggerActive, + struct { + Job string + job.ActiveSideTriggerRequest + }{ + Job: jobName, + ActiveSideTriggerRequest: job.ActiveSideTriggerRequest{ + What: what, + }, + }, + &res, + ) + + token := TriggerToken{ + Job: jobName, + InvocationId: res.InvocationId, + } + + fmt.Println(token.Encode()) + return err +} diff --git a/client/wait.go b/client/wait.go index 397df6c..92c965b 100644 --- a/client/wait.go +++ b/client/wait.go @@ -2,7 +2,6 @@ package client import ( "context" - "encoding/json" "fmt" "strconv" "time" @@ -24,7 +23,7 @@ var waitCmdArgs struct { } var WaitCmd = &cli.Subcommand{ - Use: "wait [-t TOKEN | [replication|snapshotting|prune_sender|prune_receiver JOB]]", + Use: "wait [-t TOKEN | JOB INVOCATION [replication|snapshotting|prune_sender|prune_receiver]]", Short: "", Run: func(ctx context.Context, subcommand *cli.Subcommand, args []string) error { return runWaitCmd(subcommand.Config(), args) @@ -43,22 +42,16 @@ func runWaitCmd(config *config.Config, args []string) error { return err } - var pollRequest daemon.ControlJobEndpointSignalActiveRequest + var req daemon.ControlJobEndpointWaitActiveRequest if waitCmdArgs.token != "" { - if len(args) != 0 { - return fmt.Errorf("-t and regular usage is mutually exclusive") - } - err := json.Unmarshal([]byte(waitCmdArgs.token), &pollRequest) + var token TriggerToken + err := token.Decode(resetCmdArgs.token) if err != nil { - return errors.Wrap(err, "cannot unmarshal token") + return errors.Wrap(err, "cannot decode token") } + req = token.ToWait() } else { - if args[0] != "active" { - panic(args) - } - args = args[1:] - jobName := args[0] invocationId, err := strconv.ParseUint(args[1], 10, 64) @@ -66,14 +59,11 @@ func runWaitCmd(config *config.Config, args []string) error { return errors.Wrap(err, "parse invocation id") } - waitWhat := args[2] - // updated by subsequent requests - pollRequest = daemon.ControlJobEndpointSignalActiveRequest{ + req = daemon.ControlJobEndpointWaitActiveRequest{ Job: jobName, ActiveSidePollRequest: job.ActiveSidePollRequest{ InvocationId: invocationId, - What: waitWhat, }, } } @@ -83,10 +73,10 @@ func runWaitCmd(config *config.Config, args []string) error { pollOnce := func() error { var res job.ActiveSidePollResponse if waitCmdArgs.verbose { - pretty.Println("making poll request", pollRequest) + pretty.Println("making poll request", req) } err = jsonRequestResponse(httpc, daemon.ControlJobEndpointPollActive, - pollRequest, + req, &res, ) if err != nil { @@ -101,7 +91,7 @@ func runWaitCmd(config *config.Config, args []string) error { return doneErr } - pollRequest.InvocationId = res.InvocationId + req.InvocationId = res.InvocationId return nil } diff --git a/daemon/control.go b/daemon/control.go index 38810e6..9b7822f 100644 --- a/daemon/control.go +++ b/daemon/control.go @@ -73,14 +73,25 @@ func (j *controlJob) RegisterMetrics(registerer prometheus.Registerer) { } const ( - ControlJobEndpointPProf string = "/debug/pprof" - ControlJobEndpointVersion string = "/version" - ControlJobEndpointStatus string = "/status" - ControlJobEndpointSignalActive string = "/signal/active" - ControlJobEndpointPollActive string = "/poll/active" + ControlJobEndpointPProf string = "/debug/pprof" + ControlJobEndpointVersion string = "/version" + ControlJobEndpointStatus string = "/status" + ControlJobEndpointTriggerActive string = "/signal/active" + ControlJobEndpointPollActive string = "/poll/active" + ControlJobEndpointResetActive string = "/reset/active" ) -type ControlJobEndpointSignalActiveRequest struct { +type ControlJobEndpointTriggerActiveRequest struct { + Job string + job.ActiveSideTriggerRequest +} + +type ControlJobEndpointResetActiveRequest struct { + Job string + job.ActiveSideResetRequest +} + +type ControlJobEndpointWaitActiveRequest struct { Job string job.ActiveSidePollRequest } @@ -137,7 +148,7 @@ func (j *controlJob) Run(ctx context.Context) { }}) mux.Handle(ControlJobEndpointPollActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) { - var req ControlJobEndpointSignalActiveRequest + var req ControlJobEndpointWaitActiveRequest if decoder(&req) != nil { return nil, errors.Errorf("decode failed") } @@ -164,11 +175,11 @@ func (j *controlJob) Run(ctx context.Context) { return res, err }}}) - mux.Handle(ControlJobEndpointSignalActive, + mux.Handle(ControlJobEndpointTriggerActive, requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) { type reqT struct { Job string - job.ActiveSideSignalRequest + job.ActiveSideTriggerRequest } var req reqT if decoder(&req) != nil { @@ -192,7 +203,43 @@ func (j *controlJob) Run(ctx context.Context) { return v, err } - res, err := ajo.Signal(req.ActiveSideSignalRequest) + res, err := ajo.Trigger(req.ActiveSideTriggerRequest) + + j.jobs.m.RUnlock() + + return res, err + + }}}) + + mux.Handle(ControlJobEndpointResetActive, + requestLogger{log: log, handler: jsonRequestResponder{log, func(decoder jsonDecoder) (v interface{}, err error) { + type reqT struct { + Job string + job.ActiveSideResetRequest + } + var req reqT + if decoder(&req) != nil { + return nil, errors.Errorf("decode failed") + } + + // FIXME dedup the following code with ControlJobEndpointPollActive + + j.jobs.m.RLock() + + jo, ok := j.jobs.jobs[req.Job] + if !ok { + j.jobs.m.RUnlock() + return struct{}{}, fmt.Errorf("unknown job name %q", req.Job) + } + + ajo, ok := jo.(*job.ActiveSide) + if !ok { + v, err = struct{}{}, fmt.Errorf("job %q is not an active side (it's a %T)", jo.Name(), jo) + j.jobs.m.RUnlock() + return v, err + } + + res, err := ajo.Reset(req.ActiveSideResetRequest) j.jobs.m.RUnlock() diff --git a/daemon/job/active.go b/daemon/job/active.go index 8bae016..e9f4b0b 100644 --- a/daemon/job/active.go +++ b/daemon/job/active.go @@ -8,13 +8,11 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/common/log" "github.com/zrepl/zrepl/daemon/logging/trace" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/config" - "github.com/zrepl/zrepl/daemon/job/reset" "github.com/zrepl/zrepl/daemon/pruner" "github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/endpoint" @@ -46,7 +44,8 @@ type ActiveSide struct { tasks activeSideTasks nextInvocationId uint64 activeInvocationId uint64 // 0 <=> inactive - signal chan struct{} + trigger chan struct{} + reset chan uint64 } //go:generate enumer -type=ActiveSideState @@ -437,7 +436,8 @@ func (j *ActiveSide) Run(ctx context.Context) { wakePeriodic := make(chan struct{}) go j.mode.RunPeriodic(periodicCtx, wakePeriodic, periodicDone) - j.signal = make(chan struct{}) + j.trigger = make(chan struct{}) + j.reset = make(chan uint64) j.nextInvocationId = 1 outer: @@ -448,7 +448,7 @@ outer: log.WithError(ctx.Err()).Info("context") break outer - case <-j.signal: + case <-j.trigger: j.mode.ResetConnectBackoff() case <-periodicDone: } @@ -456,10 +456,38 @@ outer: j.tasksMtx.Lock() j.activeInvocationId = j.nextInvocationId j.nextInvocationId++ + thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below j.tasksMtx.Unlock() + // setup the invocation context invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId)) + invocationCtx, cancelInvocation := context.WithCancel(invocationCtx) + + // setup the goroutine that waits for task resets + // Task resets are converted into cancellations of the invocation context. + waitForResetCtx, stopWaitForReset := context.WithCancel(ctx) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-waitForResetCtx.Done(): + return + case reqResetInvocation := <-j.reset: + l := log.WithField("requested_invocation_id", reqResetInvocation). + WithField("this_invocation_id", thisInvocation) + if reqResetInvocation == thisInvocation { + l.Info("reset received, cancelling current invocation") + cancelInvocation() + } else { + l.Debug("received reset for invocation id that is not us, discarding request") + } + } + }() + j.do(invocationCtx) + stopWaitForReset() + wg.Wait() j.tasksMtx.Lock() j.activeInvocationId = 0 @@ -471,7 +499,6 @@ outer: type ActiveSidePollRequest struct { InvocationId uint64 - What string } type ActiveSidePollResponse struct { @@ -493,22 +520,52 @@ func (j *ActiveSide) Poll(req ActiveSidePollRequest) (*ActiveSidePollResponse, e } } - switch req.What { - case "invocation": - var done bool - if j.activeInvocationId == 0 { - done = waitForId < j.nextInvocationId - } else { - done = waitForId < j.activeInvocationId - } - res := &ActiveSidePollResponse{Done: done, InvocationId: waitForId} - return res, nil - default: - return nil, fmt.Errorf("unknown wait target %q", req.What) + var done bool + if j.activeInvocationId == 0 { + done = waitForId < j.nextInvocationId + } else { + done = waitForId < j.activeInvocationId } + res := &ActiveSidePollResponse{Done: done, InvocationId: waitForId} + return res, nil } -type ActiveSideSignalRequest struct { +type ActiveSideResetRequest struct { + InvocationId uint64 +} + +type ActiveSideResetResponse struct { + InvocationId uint64 +} + +func (j *ActiveSide) Reset(req ActiveSideResetRequest) (*ActiveSideResetResponse, error) { + j.tasksMtx.Lock() + defer j.tasksMtx.Unlock() + + resetId := req.InvocationId + if req.InvocationId == 0 { + // handle the case where the client doesn't know what the current invocation id is + resetId = j.activeInvocationId + } + + if resetId == 0 { + return nil, fmt.Errorf("no active invocation") + } + + if resetId != j.activeInvocationId { + return nil, fmt.Errorf("active invocation (%d) is not the invocation requested for reset (%d); (active invocation '0' indicates no active invocation)", j.activeInvocationId, resetId) + } + + // non-blocking send (.Run() must not hold mutex while waiting for resets) + select { + case j.reset <- resetId: + default: + } + + return &ActiveSideResetResponse{InvocationId: resetId}, nil +} + +type ActiveSideTriggerRequest struct { What string } @@ -516,7 +573,7 @@ type ActiveSideSignalResponse struct { InvocationId uint64 } -func (j *ActiveSide) Signal(req ActiveSideSignalRequest) (*ActiveSideSignalResponse, error) { +func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalResponse, error) { // switch req.What { // case "replication": // invocationId, err = j.jobs.doreplication(req.Name) @@ -539,7 +596,7 @@ func (j *ActiveSide) Signal(req ActiveSideSignalRequest) (*ActiveSideSignalRespo } // non-blocking send (.Run() must not hold mutex while waiting for signals) select { - case j.signal <- struct{}{}: + case j.trigger <- struct{}{}: default: } j.tasksMtx.Unlock() @@ -554,18 +611,6 @@ func (j *ActiveSide) do(ctx context.Context) { j.mode.ConnectEndpoints(ctx, j.connecter) defer j.mode.DisconnectEndpoints() - // allow cancellation of an invocation (this function) - ctx, cancelThisRun := context.WithCancel(ctx) - defer cancelThisRun() - go func() { - select { - case <-reset.Wait(ctx): - log.Info("reset received, cancelling current invocation") - cancelThisRun() - case <-ctx.Done(): - } - }() - sender, receiver := j.mode.SenderReceiver() { diff --git a/daemon/job/snapjob.go b/daemon/job/snapjob.go index be1d021..0985a7a 100644 --- a/daemon/job/snapjob.go +++ b/daemon/job/snapjob.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" - "github.com/zrepl/zrepl/daemon/job/doreplication" "github.com/zrepl/zrepl/daemon/logging/trace" "github.com/zrepl/zrepl/util/nodefault" @@ -118,7 +117,7 @@ outer: log.WithError(ctx.Err()).Info("context") break outer - case <-doreplication.Wait(ctx): + // case <-doreplication.Wait(ctx): case <-periodicDone: } invocationCount++ diff --git a/daemon/snapper/snapper.go b/daemon/snapper/snapper.go index 7ebb8d8..5e6f075 100644 --- a/daemon/snapper/snapper.go +++ b/daemon/snapper/snapper.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" - "github.com/zrepl/zrepl/daemon/job/dosnapshot" "github.com/zrepl/zrepl/daemon/logging/trace" "github.com/zrepl/zrepl/config" @@ -211,10 +210,10 @@ func syncUp(a args, u updater) state { return u(func(s *Snapper) { s.state = Planning }).sf() - case <-dosnapshot.Wait(a.ctx): - return u(func(s *Snapper) { - s.state = Planning - }).sf() + // case <-dosnapshot.Wait(a.ctx): + // return u(func(s *Snapper) { + // s.state = Planning + // }).sf() case <-a.ctx.Done(): return onMainCtxDone(a.ctx, u) } @@ -383,10 +382,10 @@ func wait(a args, u updater) state { return u(func(snapper *Snapper) { snapper.state = Planning }).sf() - case <-dosnapshot.Wait(a.ctx): - return u(func(snapper *Snapper) { - snapper.state = Planning - }).sf() + // case <-dosnapshot.Wait(a.ctx): + // return u(func(snapper *Snapper) { + // snapper.state = Planning + // }).sf() case <-a.ctx.Done(): return onMainCtxDone(a.ctx, u) } diff --git a/main.go b/main.go index 3f7625f..24f27c0 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,8 @@ import ( func init() { cli.AddSubcommand(daemon.DaemonCmd) cli.AddSubcommand(status.Subcommand) - cli.AddSubcommand(client.SignalCmd) + cli.AddSubcommand(client.TriggerCmd) + cli.AddSubcommand(client.ResetCmd) cli.AddSubcommand(client.WaitCmd) cli.AddSubcommand(client.StdinserverCmd) cli.AddSubcommand(client.ConfigcheckCmd)