From f28676b8d7e1e92ea3eb72f36b4a17885ee91879 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Wed, 24 Mar 2021 00:02:10 +0100 Subject: [PATCH] WIP: extract logic for WaitForTrigger, Trigger, Reset, Poll into a reusable abstraction --- daemon/job/active.go | 126 +++++++++---------- daemon/job/trigger/trigger.go | 187 +++++++++++++++++++++++++++++ daemon/job/trigger/trigger_test.go | 84 +++++++++++++ 3 files changed, 336 insertions(+), 61 deletions(-) create mode 100644 daemon/job/trigger/trigger.go create mode 100644 daemon/job/trigger/trigger_test.go diff --git a/daemon/job/active.go b/daemon/job/active.go index e9f4b0b..0462f85 100644 --- a/daemon/job/active.go +++ b/daemon/job/active.go @@ -440,58 +440,67 @@ func (j *ActiveSide) Run(ctx context.Context) { j.reset = make(chan uint64) j.nextInvocationId = 1 -outer: - for { - log.Info("wait for replications") - select { - case <-ctx.Done(): - log.WithError(ctx.Err()).Info("context") - break outer + type WaitTriggerResult interface { + + } + var t interface{ + WaitForTrigger(context.Context) (context.Context, <-chan WaitTriggerResult) + } - case <-j.trigger: - j.mode.ResetConnectBackoff() - case <-periodicDone: + for { + log.Info("wait for triggers") + + // j.tasksMtx.Lock() + // j.activeInvocationId = j.nextInvocationId + // j.nextInvocationId++ + // thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below + // j.tasksMtx.Unlock() + + // // setup the goroutine that waits for task resets + // // Task resets are converted into cancellations of the invocation context. + + invocationCtx, cancelInvocation := context.WithCancel(invocationCtx) + // waitForResetCtx, stopWaitForReset := context.WithCancel(ctx) + // var wg sync.WaitGroup + // wg.Add(1) + // go func() { + // defer wg.Done() + // select { + // case <-waitForResetCtx.Done(): + // return + // case reqResetInvocation := <-j.reset: + // l := log.WithField("requested_invocation_id", reqResetInvocation). 
+ // WithField("this_invocation_id", thisInvocation) + // if reqResetInvocation == thisInvocation { + // l.Info("reset received, cancelling current invocation") + // cancelInvocation() + // } else { + // l.Debug("received reset for invocation id that is not us, discarding request") + // } + // } + // }() + + // j.tasksMtx.Lock() + // j.activeInvocationId = 0 + // j.tasksMtx.Unlock() + + invocationCtx, err := t.WaitForTrigger(ctx) + if err != nil { + log.WithError(ctx.Err()).Info("error waiting for trigger") + break } - j.tasksMtx.Lock() - j.activeInvocationId = j.nextInvocationId - j.nextInvocationId++ - thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below - j.tasksMtx.Unlock() + j.mode.ResetConnectBackoff() // setup the invocation context - invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId)) - invocationCtx, cancelInvocation := context.WithCancel(invocationCtx) - - // setup the goroutine that waits for task resets - // Task resets are converted into cancellations of the invocation context. - waitForResetCtx, stopWaitForReset := context.WithCancel(ctx) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - select { - case <-waitForResetCtx.Done(): - return - case reqResetInvocation := <-j.reset: - l := log.WithField("requested_invocation_id", reqResetInvocation). 
- WithField("this_invocation_id", thisInvocation) - if reqResetInvocation == thisInvocation { - l.Info("reset received, cancelling current invocation") - cancelInvocation() - } else { - l.Debug("received reset for invocation id that is not us, discarding request") - } - } - }() + invocationCtx, endSpan := trace.WithSpan(invocationCtx, fmt.Sprintf("invocation-%d", j.nextInvocationId)) + j.do(invocationCtx) stopWaitForReset() wg.Wait() - j.tasksMtx.Lock() - j.activeInvocationId = 0 - j.tasksMtx.Unlock() + endSpan() } @@ -566,11 +575,11 @@ func (j *ActiveSide) Reset(req ActiveSideResetRequest) (*ActiveSideResetResponse } type ActiveSideTriggerRequest struct { - What string } type ActiveSideSignalResponse struct { InvocationId uint64 + } func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalResponse, error) { @@ -585,25 +594,20 @@ func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalRes // err = fmt.Errorf("operation %q is invalid", req.Op) // } - switch req.What { - case "invocation": - j.tasksMtx.Lock() - var invocationId uint64 - if j.activeInvocationId != 0 { - invocationId = j.activeInvocationId - } else { - invocationId = j.nextInvocationId - } - // non-blocking send (.Run() must not hold mutex while waiting for signals) - select { - case j.trigger <- struct{}{}: - default: - } - j.tasksMtx.Unlock() - return &ActiveSideSignalResponse{InvocationId: invocationId}, nil - default: - return nil, fmt.Errorf("unknown signal %q", req.What) + j.tasksMtx.Lock() + var invocationId uint64 + if j.activeInvocationId != 0 { + invocationId = j.activeInvocationId + } else { + invocationId = j.nextInvocationId } + // non-blocking send (.Run() must not hold mutex while waiting for signals) + select { + case j.trigger <- struct{}{}: + default: + } + j.tasksMtx.Unlock() + return &ActiveSideSignalResponse{InvocationId: invocationId}, nil } func (j *ActiveSide) do(ctx context.Context) { diff --git a/daemon/job/trigger/trigger.go 
b/daemon/job/trigger/trigger.go
new file mode 100644
index 0000000..5aea02b
--- /dev/null
+++ b/daemon/job/trigger/trigger.go
@@ -0,0 +1,229 @@
+// Alternative Design (in "RustGo")
+//
+// enum InternalMsg {
+//   Trigger((), chan (TriggerResponse, error)),
+//   Poll(PollRequest, chan PollResponse),
+//   Reset(ResetRequest, chan (ResetResponse, error)),
+// }
+//
+// enum State {
+//   Running{
+//     invocationId: u32,
+//     cancelCurrentInvocation: context.CancelFunc
+//   }
+//   Waiting{
+//     nextInvocationId: u32,
+//   }
+// }
+//
+// for msg := <- t.internalMsgs {
+//   match (msg, state) {
+//     ...
+//   }
+// }
+
+// Package trigger provides T, a reusable implementation of the
+// WaitForTrigger / Trigger / Poll / Reset logic of a job.
+//
+// One worker goroutine loops over WaitForTrigger and performs an invocation
+// each time it returns. Other goroutines call Trigger to request an
+// invocation, Poll to check whether an invocation has finished, and Reset to
+// cancel the invocation that is currently active.
+package trigger
+
+import (
+	"context"
+	"fmt"
+	"sync"
+)
+
+// T tracks which invocation is active and whether a trigger is pending.
+//
+// Use New to create a T; the zero value is not usable. A T must not be
+// copied because it contains a sync.Mutex and a sync.Cond.
+type T struct {
+	mtx sync.Mutex
+	cv  sync.Cond // cv.L is &mtx
+
+	nextInvocationId   uint64 // id assigned to the next invocation
+	activeInvocationId uint64 // 0 <=> inactive
+	waiting            bool   // true while a WaitForTrigger call is in progress
+	triggerPending     bool   // set by Trigger, consumed by WaitForTrigger
+
+	// cancelCurrentInvocation cancels the context returned by the most recent
+	// successful WaitForTrigger call. It is non-nil iff activeInvocationId != 0.
+	cancelCurrentInvocation context.CancelFunc
+}
+
+// New returns a T that is inactive and whose first invocation gets id 1.
+func New() *T {
+	t := &T{nextInvocationId: 1}
+	t.cv.L = &t.mtx
+	return t
+}
+
+// WaitForTrigger marks the previous invocation (if any) as finished, cancels
+// its context, and blocks until Trigger is called or ctx is done.
+//
+// On a trigger it assigns the next invocation id, marks that invocation as
+// active, and returns a child context of ctx that is cancelled by Reset or by
+// the next call to WaitForTrigger. A trigger that arrived while the previous
+// invocation was still active is not lost: it makes this call return
+// immediately.
+//
+// If ctx is done, WaitForTrigger returns ctx.Err() and T stays inactive. It
+// also returns an error if another WaitForTrigger call is already in
+// progress, because T supports only a single worker goroutine.
+func (t *T) WaitForTrigger(ctx context.Context) (rctx context.Context, err error) {
+	t.mtx.Lock()
+	defer t.mtx.Unlock()
+
+	if t.waiting {
+		return nil, fmt.Errorf("WaitForTrigger must not be called concurrently")
+	}
+	t.waiting = true
+	defer func() { t.waiting = false }()
+
+	// Re-entering WaitForTrigger means the caller is done with the previous
+	// invocation; cancel its context so that its resources are released.
+	if t.cancelCurrentInvocation != nil {
+		t.cancelCurrentInvocation()
+		t.cancelCurrentInvocation = nil
+	}
+	t.activeInvocationId = 0
+
+	// ctxDone is local to this call (and only accessed with t.mtx held) so
+	// that a watcher goroutine that loses the race against a trigger cannot
+	// leak its "context done" state into a later WaitForTrigger call.
+	ctxDone := false
+	stopWaitingForDone := make(chan struct{})
+	defer close(stopWaitingForDone)
+	go func() {
+		select {
+		case <-stopWaitingForDone:
+		case <-ctx.Done():
+			t.mtx.Lock()
+			ctxDone = true
+			t.cv.Broadcast()
+			t.mtx.Unlock()
+		}
+	}()
+
+	for !t.triggerPending && !ctxDone {
+		t.cv.Wait()
+	}
+	t.triggerPending = false
+	if err = ctx.Err(); err != nil {
+		return nil, err
+	}
+
+	t.activeInvocationId = t.nextInvocationId
+	t.nextInvocationId++
+	rctx, t.cancelCurrentInvocation = context.WithCancel(ctx)
+
+	return rctx, nil
+}
+
+// TriggerResponse is the result of Trigger.
+type TriggerResponse struct {
+	// InvocationId is the id of the active invocation or, if there is none,
+	// the id that the next invocation will get.
+	InvocationId uint64
+}
+
+// Trigger requests an invocation and wakes up a blocked WaitForTrigger.
+// If no goroutine is blocked in WaitForTrigger, the request stays pending
+// and the next WaitForTrigger call returns immediately.
+// The returned error is always nil.
+func (t *T) Trigger() (TriggerResponse, error) {
+	t.mtx.Lock()
+	defer t.mtx.Unlock()
+
+	invocationId := t.activeInvocationId
+	if invocationId == 0 {
+		invocationId = t.nextInvocationId
+	}
+	t.triggerPending = true
+	t.cv.Broadcast()
+	return TriggerResponse{InvocationId: invocationId}, nil
+}
+
+// PollRequest is the argument of Poll.
+type PollRequest struct {
+	// InvocationId is the invocation to ask about; 0 means the active
+	// invocation or, if there is none, the next one.
+	InvocationId uint64
+}
+
+// PollResponse is the result of Poll.
+type PollResponse struct {
+	// Done reports whether invocation InvocationId has finished.
+	Done bool
+	// InvocationId is the invocation that Done refers to.
+	InvocationId uint64
+}
+
+// Poll reports whether the requested invocation has finished, i.e., whether
+// its id is lower than that of the active invocation or, if there is no
+// active invocation, lower than the id of the next invocation.
+func (t *T) Poll(req PollRequest) (res PollResponse) {
+	t.mtx.Lock()
+	defer t.mtx.Unlock()
+
+	waitForId := req.InvocationId
+	if req.InvocationId == 0 {
+		// handle the case where the client doesn't know what the current invocation id is
+		if t.activeInvocationId != 0 {
+			waitForId = t.activeInvocationId
+		} else {
+			waitForId = t.nextInvocationId
+		}
+	}
+
+	var done bool
+	if t.activeInvocationId == 0 {
+		done = waitForId < t.nextInvocationId
+	} else {
+		done = waitForId < t.activeInvocationId
+	}
+	return PollResponse{Done: done, InvocationId: waitForId}
+}
+
+// ResetRequest is the argument of Reset.
+type ResetRequest struct {
+	// InvocationId is the invocation to cancel; 0 means the active one.
+	InvocationId uint64
+}
+
+// ResetResponse is the result of Reset.
+type ResetResponse struct {
+	// InvocationId is the invocation that was cancelled.
+	InvocationId uint64
+}
+
+// Reset cancels the context of the active invocation.
+// It returns an error if there is no active invocation or if the requested
+// invocation is not the active one.
+func (t *T) Reset(req ResetRequest) (*ResetResponse, error) {
+	t.mtx.Lock()
+	defer t.mtx.Unlock()
+
+	resetId := req.InvocationId
+	if req.InvocationId == 0 {
+		// handle the case where the client doesn't know what the current invocation id is
+		resetId = t.activeInvocationId
+	}
+
+	if resetId == 0 {
+		return nil, fmt.Errorf("no active invocation")
+	}
+
+	if resetId != t.activeInvocationId {
+		return nil, fmt.Errorf("active invocation (%d) is not the invocation requested for reset (%d); (active invocation '0' indicates no active invocation)", t.activeInvocationId, resetId)
+	}
+
+	// activeInvocationId != 0 here, so cancelCurrentInvocation is non-nil.
+	t.cancelCurrentInvocation()
+
+	return &ResetResponse{InvocationId: resetId}, nil
+}
diff --git a/daemon/job/trigger/trigger_test.go b/daemon/job/trigger/trigger_test.go
new file mode 100644
index 0000000..0542530
--- /dev/null
+++ b/daemon/job/trigger/trigger_test.go
@@ -0,0 +1,109 @@
+package trigger
+
+import (
+	"context"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+// testContextKey is the key of the context value that TestBasics uses to
+// check that invocation contexts are derived from the task context.
+type testContextKey struct{}
+
+func TestBasics(t *testing.T) {
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+
+	tr := New()
+
+	triggered := make(chan int)
+	waitForTriggerError := make(chan error)
+	waitForResetCallToBeMadeByMainGoroutine := make(chan struct{})
+	postResetAssertionsDone := make(chan struct{})
+
+	taskCtx := context.Background()
+	taskCtx, cancelTaskCtx := context.WithCancel(taskCtx)
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		taskCtx := context.WithValue(taskCtx, testContextKey{}, "myvalue")
+
+		triggers := 0
+
+		// This is not the test goroutine, so failures are reported with
+		// t.Errorf: require would call t.FailNow, which must only be called
+		// from the goroutine running the test function.
+		for {
+			invocationCtx, err := tr.WaitForTrigger(taskCtx)
+			if err != nil {
+				waitForTriggerError <- err
+				return
+			}
+			if v := invocationCtx.Value(testContextKey{}); v != "myvalue" {
+				t.Errorf("invocation context must inherit values of the task context, got %#v", v)
+			}
+
+			triggers++
+			triggered <- triggers
+
+			if triggers == 2 {
+				<-waitForResetCallToBeMadeByMainGoroutine
+				if err := invocationCtx.Err(); err != context.Canceled {
+					t.Errorf("Reset() cancels invocation context, got %v", err)
+				}
+				if err := taskCtx.Err(); err != nil {
+					t.Errorf("Reset() does not cancel task context, got %v", err)
+				}
+				close(postResetAssertionsDone)
+			}
+		}
+
+	}()
+
+	t.Logf("trigger 1")
+	_, err := tr.Trigger()
+	require.NoError(t, err)
+	v := <-triggered
+	require.Equal(t, 1, v)
+
+	t.Logf("trigger 2")
+	_, err = tr.Trigger()
+	require.NoError(t, err)
+	v = <-triggered
+	require.Equal(t, 2, v)
+
+	t.Logf("reset")
+	// The id returned by Trigger depends on whether the worker goroutine had
+	// already re-entered WaitForTrigger, so reset "the active invocation"
+	// (id 0) and check that it is the second one.
+	resetResponse, err := tr.Reset(ResetRequest{})
+	require.NoError(t, err)
+	require.Equal(t, uint64(2), resetResponse.InvocationId)
+	t.Logf("reset response: %#v", resetResponse)
+	close(waitForResetCallToBeMadeByMainGoroutine)
+	<-postResetAssertionsDone
+
+	t.Logf("cancel the context")
+	cancelTaskCtx()
+	wfte := <-waitForTriggerError
+	require.Equal(t, taskCtx.Err(), wfte)
+
+}
+
+func TestInitialState(t *testing.T) {
+	tr := New()
+
+	_, err := tr.Reset(ResetRequest{})
+	require.Error(t, err, "Reset() without an active invocation")
+
+	require.False(t, tr.Poll(PollRequest{InvocationId: 1}).Done, "invocation 1 has not run yet")
+
+	triggerResponse, err := tr.Trigger()
+	require.NoError(t, err)
+	require.Equal(t, uint64(1), triggerResponse.InvocationId)
+}