mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-24 17:35:01 +01:00
WIP: extract logic for WaitForTrigger, Trigger, Reset, Poll into a reusable abstraction
This commit is contained in:
parent
3e6cae1c8f
commit
f28676b8d7
@ -440,58 +440,67 @@ func (j *ActiveSide) Run(ctx context.Context) {
|
||||
j.reset = make(chan uint64)
|
||||
j.nextInvocationId = 1
|
||||
|
||||
outer:
|
||||
for {
|
||||
log.Info("wait for replications")
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.WithError(ctx.Err()).Info("context")
|
||||
break outer
|
||||
type WaitTriggerResult interface {
|
||||
|
||||
case <-j.trigger:
|
||||
j.mode.ResetConnectBackoff()
|
||||
case <-periodicDone:
|
||||
}
|
||||
var t interface{
|
||||
WaitForTrigger(context.Context) (context.Context, <-chan WaitTriggerResult)
|
||||
}
|
||||
|
||||
for {
|
||||
log.Info("wait for triggers")
|
||||
|
||||
// j.tasksMtx.Lock()
|
||||
// j.activeInvocationId = j.nextInvocationId
|
||||
// j.nextInvocationId++
|
||||
// thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below
|
||||
// j.tasksMtx.Unlock()
|
||||
|
||||
// // setup the goroutine that waits for task resets
|
||||
// // Task resets are converted into cancellations of the invocation context.
|
||||
|
||||
invocationCtx, cancelInvocation := context.WithCancel(invocationCtx)
|
||||
// waitForResetCtx, stopWaitForReset := context.WithCancel(ctx)
|
||||
// var wg sync.WaitGroup
|
||||
// wg.Add(1)
|
||||
// go func() {
|
||||
// defer wg.Done()
|
||||
// select {
|
||||
// case <-waitForResetCtx.Done():
|
||||
// return
|
||||
// case reqResetInvocation := <-j.reset:
|
||||
// l := log.WithField("requested_invocation_id", reqResetInvocation).
|
||||
// WithField("this_invocation_id", thisInvocation)
|
||||
// if reqResetInvocation == thisInvocation {
|
||||
// l.Info("reset received, cancelling current invocation")
|
||||
// cancelInvocation()
|
||||
// } else {
|
||||
// l.Debug("received reset for invocation id that is not us, discarding request")
|
||||
// }
|
||||
// }
|
||||
// }()
|
||||
|
||||
// j.tasksMtx.Lock()
|
||||
// j.activeInvocationId = 0
|
||||
// j.tasksMtx.Unlock()
|
||||
|
||||
invocationCtx, err := t.WaitForTrigger(ctx)
|
||||
if err != nil {
|
||||
log.WithError(ctx.Err()).Info("error waiting for trigger")
|
||||
break
|
||||
}
|
||||
|
||||
j.tasksMtx.Lock()
|
||||
j.activeInvocationId = j.nextInvocationId
|
||||
j.nextInvocationId++
|
||||
thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below
|
||||
j.tasksMtx.Unlock()
|
||||
j.mode.ResetConnectBackoff()
|
||||
|
||||
// setup the invocation context
|
||||
invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
|
||||
invocationCtx, cancelInvocation := context.WithCancel(invocationCtx)
|
||||
invocationCtx, endSpan := trace.WithSpan(invocationCtx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
|
||||
|
||||
// setup the goroutine that waits for task resets
|
||||
// Task resets are converted into cancellations of the invocation context.
|
||||
waitForResetCtx, stopWaitForReset := context.WithCancel(ctx)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
select {
|
||||
case <-waitForResetCtx.Done():
|
||||
return
|
||||
case reqResetInvocation := <-j.reset:
|
||||
l := log.WithField("requested_invocation_id", reqResetInvocation).
|
||||
WithField("this_invocation_id", thisInvocation)
|
||||
if reqResetInvocation == thisInvocation {
|
||||
l.Info("reset received, cancelling current invocation")
|
||||
cancelInvocation()
|
||||
} else {
|
||||
l.Debug("received reset for invocation id that is not us, discarding request")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
j.do(invocationCtx)
|
||||
stopWaitForReset()
|
||||
wg.Wait()
|
||||
|
||||
j.tasksMtx.Lock()
|
||||
j.activeInvocationId = 0
|
||||
j.tasksMtx.Unlock()
|
||||
|
||||
|
||||
endSpan()
|
||||
}
|
||||
@ -566,11 +575,11 @@ func (j *ActiveSide) Reset(req ActiveSideResetRequest) (*ActiveSideResetResponse
|
||||
}
|
||||
|
||||
type ActiveSideTriggerRequest struct {
|
||||
What string
|
||||
}
|
||||
|
||||
type ActiveSideSignalResponse struct {
|
||||
InvocationId uint64
|
||||
|
||||
}
|
||||
|
||||
func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalResponse, error) {
|
||||
@ -585,25 +594,20 @@ func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalRes
|
||||
// err = fmt.Errorf("operation %q is invalid", req.Op)
|
||||
// }
|
||||
|
||||
switch req.What {
|
||||
case "invocation":
|
||||
j.tasksMtx.Lock()
|
||||
var invocationId uint64
|
||||
if j.activeInvocationId != 0 {
|
||||
invocationId = j.activeInvocationId
|
||||
} else {
|
||||
invocationId = j.nextInvocationId
|
||||
}
|
||||
// non-blocking send (.Run() must not hold mutex while waiting for signals)
|
||||
select {
|
||||
case j.trigger <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
j.tasksMtx.Unlock()
|
||||
return &ActiveSideSignalResponse{InvocationId: invocationId}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown signal %q", req.What)
|
||||
j.tasksMtx.Lock()
|
||||
var invocationId uint64
|
||||
if j.activeInvocationId != 0 {
|
||||
invocationId = j.activeInvocationId
|
||||
} else {
|
||||
invocationId = j.nextInvocationId
|
||||
}
|
||||
// non-blocking send (.Run() must not hold mutex while waiting for signals)
|
||||
select {
|
||||
case j.trigger <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
j.tasksMtx.Unlock()
|
||||
return &ActiveSideSignalResponse{InvocationId: invocationId}, nil
|
||||
}
|
||||
|
||||
func (j *ActiveSide) do(ctx context.Context) {
|
||||
|
187
daemon/job/trigger/trigger.go
Normal file
187
daemon/job/trigger/trigger.go
Normal file
@ -0,0 +1,187 @@
|
||||
//
|
||||
//
|
||||
// Alternative Design (in "RustGo")
|
||||
//
|
||||
// enum InternalMsg {
|
||||
// Trigger((), chan (TriggerResponse, error)),
|
||||
// Poll(PollRequest, chan PollResponse),
|
||||
// Reset(ResetRequest, chan (ResetResponse, error)),
|
||||
// }
|
||||
//
|
||||
// enum State {
|
||||
// Running{
|
||||
// invocationId: u32,
|
||||
// cancelCurrentInvocation: context.CancelFunc
|
||||
// }
|
||||
// Waiting{
|
||||
// nextInvocationId: u32,
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// for msg := <- t.internalMsgs {
|
||||
// match (msg, state) {
|
||||
// ...
|
||||
// }
|
||||
// }
|
||||
package trigger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"sync"
|
||||
|
||||
"github.com/zrepl/zrepl/daemon/logging"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
)
|
||||
|
||||
type T struct {
|
||||
mtx sync.Mutex
|
||||
cv sync.Cond
|
||||
|
||||
nextInvocationId uint64
|
||||
activeInvocationId uint64 // 0 <=> inactive
|
||||
triggerPending bool
|
||||
contextDone bool
|
||||
reset chan uint64
|
||||
stopWaitForReset chan struct{}
|
||||
cancelCurrentInvocation context.CancelFunc
|
||||
}
|
||||
|
||||
func New() *T {
|
||||
t := &T{
|
||||
activeInvocationId: math.MaxUint64,
|
||||
nextInvocationId: 1,
|
||||
}
|
||||
t.cv.L = &t.mtx
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *T) WaitForTrigger(ctx context.Context) (rctx context.Context, err error) {
|
||||
t.mtx.Lock()
|
||||
defer t.mtx.Unlock()
|
||||
|
||||
if t.activeInvocationId == 0 {
|
||||
return nil, fmt.Errorf("must be running when calling this function")
|
||||
}
|
||||
t.activeInvocationId = 0
|
||||
t.cancelCurrentInvocation = nil
|
||||
|
||||
if t.contextDone == true {
|
||||
panic("implementation error: this variable is only true while in WaitForTrigger, and that's a mutually exclusive function")
|
||||
}
|
||||
stopWaitingForDone := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-stopWaitingForDone:
|
||||
case <-ctx.Done():
|
||||
t.mtx.Lock()
|
||||
t.contextDone = true
|
||||
t.cv.Broadcast()
|
||||
t.mtx.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
t.triggerPending = false
|
||||
t.contextDone = false
|
||||
}()
|
||||
for !t.triggerPending && !t.contextDone {
|
||||
t.cv.Wait()
|
||||
}
|
||||
close(stopWaitingForDone)
|
||||
if t.contextDone {
|
||||
if ctx.Err() == nil {
|
||||
panic("implementation error: contextDone <=> ctx.Err() != nil")
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
t.activeInvocationId = t.nextInvocationId
|
||||
t.nextInvocationId++
|
||||
rctx, t.cancelCurrentInvocation = context.WithCancel(ctx)
|
||||
|
||||
return rctx, nil
|
||||
}
|
||||
|
||||
type TriggerResponse struct {
|
||||
InvocationId uint64
|
||||
}
|
||||
|
||||
func (t *T) Trigger() (TriggerResponse, error) {
|
||||
t.mtx.Lock()
|
||||
defer t.mtx.Unlock()
|
||||
var invocationId uint64
|
||||
if t.activeInvocationId != 0 {
|
||||
invocationId = t.activeInvocationId
|
||||
} else {
|
||||
invocationId = t.nextInvocationId
|
||||
}
|
||||
// non-blocking send (.Run() must not hold mutex while waiting for signals)
|
||||
t.triggerPending = true
|
||||
t.cv.Broadcast()
|
||||
return TriggerResponse{InvocationId: invocationId}, nil
|
||||
}
|
||||
|
||||
type PollRequest struct {
|
||||
InvocationId uint64
|
||||
}
|
||||
|
||||
type PollResponse struct {
|
||||
Done bool
|
||||
InvocationId uint64
|
||||
}
|
||||
|
||||
func (t *T) Poll(req PollRequest) (res PollResponse) {
|
||||
t.mtx.Lock()
|
||||
defer t.mtx.Unlock()
|
||||
|
||||
waitForId := req.InvocationId
|
||||
if req.InvocationId == 0 {
|
||||
// handle the case where the client doesn't know what the current invocation id is
|
||||
if t.activeInvocationId != 0 {
|
||||
waitForId = t.activeInvocationId
|
||||
} else {
|
||||
waitForId = t.nextInvocationId
|
||||
}
|
||||
}
|
||||
|
||||
var done bool
|
||||
if t.activeInvocationId == 0 {
|
||||
done = waitForId < t.nextInvocationId
|
||||
} else {
|
||||
done = waitForId < t.activeInvocationId
|
||||
}
|
||||
return PollResponse{Done: done, InvocationId: waitForId}
|
||||
}
|
||||
|
||||
type ResetRequest struct {
|
||||
InvocationId uint64
|
||||
}
|
||||
|
||||
type ResetResponse struct {
|
||||
InvocationId uint64
|
||||
}
|
||||
|
||||
func (t *T) Reset(req ResetRequest) (*ResetResponse, error) {
|
||||
t.mtx.Lock()
|
||||
defer t.mtx.Unlock()
|
||||
|
||||
resetId := req.InvocationId
|
||||
if req.InvocationId == 0 {
|
||||
// handle the case where the client doesn't know what the current invocation id is
|
||||
resetId = t.activeInvocationId
|
||||
}
|
||||
|
||||
if resetId == 0 {
|
||||
return nil, fmt.Errorf("no active invocation")
|
||||
}
|
||||
|
||||
if resetId != t.activeInvocationId {
|
||||
return nil, fmt.Errorf("active invocation (%d) is not the invocation requested for reset (%d); (active invocation '0' indicates no active invocation)", t.activeInvocationId, resetId)
|
||||
}
|
||||
|
||||
t.cancelCurrentInvocation()
|
||||
|
||||
return &ResetResponse{InvocationId: resetId}, nil
|
||||
}
|
84
daemon/job/trigger/trigger_test.go
Normal file
84
daemon/job/trigger/trigger_test.go
Normal file
@ -0,0 +1,84 @@
|
||||
package trigger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasics(t *testing.T) {
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
tr := New()
|
||||
|
||||
triggered := make(chan int)
|
||||
waitForTriggerError := make(chan error)
|
||||
waitForResetCallToBeMadeByMainGoroutine := make(chan struct{})
|
||||
postResetAssertionsDone := make(chan struct{})
|
||||
|
||||
taskCtx := context.Background()
|
||||
taskCtx, cancelTaskCtx := context.WithCancel(taskCtx)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
taskCtx := context.WithValue(taskCtx, "mykey", "myvalue")
|
||||
|
||||
triggers := 0
|
||||
|
||||
outer:
|
||||
for {
|
||||
invocationCtx, err := tr.WaitForTrigger(taskCtx)
|
||||
if err != nil {
|
||||
waitForTriggerError <- err
|
||||
return
|
||||
}
|
||||
require.Equal(t, invocationCtx.Value("mykey"), "myvalue")
|
||||
|
||||
triggers++
|
||||
triggered <- triggers
|
||||
|
||||
switch triggers {
|
||||
case 1:
|
||||
continue outer
|
||||
case 2:
|
||||
<-waitForResetCallToBeMadeByMainGoroutine
|
||||
require.Equal(t, context.Canceled, invocationCtx.Err(), "Reset() cancels invocation context")
|
||||
require.Nil(t, taskCtx.Err(), "Reset() does not cancel task context")
|
||||
close(postResetAssertionsDone)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
t.Logf("trigger 1")
|
||||
_, err := tr.Trigger()
|
||||
require.NoError(t, err)
|
||||
v := <-triggered
|
||||
require.Equal(t, 1, v)
|
||||
|
||||
t.Logf("trigger 2")
|
||||
triggerResponse, err := tr.Trigger()
|
||||
require.NoError(t, err)
|
||||
v = <-triggered
|
||||
require.Equal(t, 2, v)
|
||||
|
||||
t.Logf("reset")
|
||||
resetResponse, err := tr.Reset(ResetRequest{InvocationId: triggerResponse.InvocationId})
|
||||
require.NoError(t, err)
|
||||
t.Logf("reset response: %#v", resetResponse)
|
||||
close(waitForResetCallToBeMadeByMainGoroutine)
|
||||
<-postResetAssertionsDone
|
||||
|
||||
t.Logf("cancel the context")
|
||||
cancelTaskCtx()
|
||||
wfte := <-waitForTriggerError
|
||||
require.Equal(t, taskCtx.Err(), wfte)
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user