mirror of
https://github.com/zrepl/zrepl.git
synced 2025-02-17 19:01:12 +01:00
WIP: extract logic for WaitForTrigger, Trigger, Reset, Poll into a reusable abstraction
This commit is contained in:
parent
3e6cae1c8f
commit
f28676b8d7
@ -440,58 +440,67 @@ func (j *ActiveSide) Run(ctx context.Context) {
|
|||||||
j.reset = make(chan uint64)
|
j.reset = make(chan uint64)
|
||||||
j.nextInvocationId = 1
|
j.nextInvocationId = 1
|
||||||
|
|
||||||
outer:
|
type WaitTriggerResult interface {
|
||||||
for {
|
|
||||||
log.Info("wait for replications")
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.WithError(ctx.Err()).Info("context")
|
|
||||||
break outer
|
|
||||||
|
|
||||||
case <-j.trigger:
|
}
|
||||||
j.mode.ResetConnectBackoff()
|
var t interface{
|
||||||
case <-periodicDone:
|
WaitForTrigger(context.Context) (context.Context, <-chan WaitTriggerResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
log.Info("wait for triggers")
|
||||||
|
|
||||||
|
// j.tasksMtx.Lock()
|
||||||
|
// j.activeInvocationId = j.nextInvocationId
|
||||||
|
// j.nextInvocationId++
|
||||||
|
// thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below
|
||||||
|
// j.tasksMtx.Unlock()
|
||||||
|
|
||||||
|
// // setup the goroutine that waits for task resets
|
||||||
|
// // Task resets are converted into cancellations of the invocation context.
|
||||||
|
|
||||||
|
invocationCtx, cancelInvocation := context.WithCancel(invocationCtx)
|
||||||
|
// waitForResetCtx, stopWaitForReset := context.WithCancel(ctx)
|
||||||
|
// var wg sync.WaitGroup
|
||||||
|
// wg.Add(1)
|
||||||
|
// go func() {
|
||||||
|
// defer wg.Done()
|
||||||
|
// select {
|
||||||
|
// case <-waitForResetCtx.Done():
|
||||||
|
// return
|
||||||
|
// case reqResetInvocation := <-j.reset:
|
||||||
|
// l := log.WithField("requested_invocation_id", reqResetInvocation).
|
||||||
|
// WithField("this_invocation_id", thisInvocation)
|
||||||
|
// if reqResetInvocation == thisInvocation {
|
||||||
|
// l.Info("reset received, cancelling current invocation")
|
||||||
|
// cancelInvocation()
|
||||||
|
// } else {
|
||||||
|
// l.Debug("received reset for invocation id that is not us, discarding request")
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
|
||||||
|
// j.tasksMtx.Lock()
|
||||||
|
// j.activeInvocationId = 0
|
||||||
|
// j.tasksMtx.Unlock()
|
||||||
|
|
||||||
|
invocationCtx, err := t.WaitForTrigger(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(ctx.Err()).Info("error waiting for trigger")
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
j.tasksMtx.Lock()
|
j.mode.ResetConnectBackoff()
|
||||||
j.activeInvocationId = j.nextInvocationId
|
|
||||||
j.nextInvocationId++
|
|
||||||
thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below
|
|
||||||
j.tasksMtx.Unlock()
|
|
||||||
|
|
||||||
// setup the invocation context
|
// setup the invocation context
|
||||||
invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
|
invocationCtx, endSpan := trace.WithSpan(invocationCtx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
|
||||||
invocationCtx, cancelInvocation := context.WithCancel(invocationCtx)
|
|
||||||
|
|
||||||
// setup the goroutine that waits for task resets
|
|
||||||
// Task resets are converted into cancellations of the invocation context.
|
|
||||||
waitForResetCtx, stopWaitForReset := context.WithCancel(ctx)
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
select {
|
|
||||||
case <-waitForResetCtx.Done():
|
|
||||||
return
|
|
||||||
case reqResetInvocation := <-j.reset:
|
|
||||||
l := log.WithField("requested_invocation_id", reqResetInvocation).
|
|
||||||
WithField("this_invocation_id", thisInvocation)
|
|
||||||
if reqResetInvocation == thisInvocation {
|
|
||||||
l.Info("reset received, cancelling current invocation")
|
|
||||||
cancelInvocation()
|
|
||||||
} else {
|
|
||||||
l.Debug("received reset for invocation id that is not us, discarding request")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
j.do(invocationCtx)
|
j.do(invocationCtx)
|
||||||
stopWaitForReset()
|
stopWaitForReset()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
j.tasksMtx.Lock()
|
|
||||||
j.activeInvocationId = 0
|
|
||||||
j.tasksMtx.Unlock()
|
|
||||||
|
|
||||||
endSpan()
|
endSpan()
|
||||||
}
|
}
|
||||||
@ -566,11 +575,11 @@ func (j *ActiveSide) Reset(req ActiveSideResetRequest) (*ActiveSideResetResponse
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ActiveSideTriggerRequest struct {
|
type ActiveSideTriggerRequest struct {
|
||||||
What string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ActiveSideSignalResponse struct {
|
type ActiveSideSignalResponse struct {
|
||||||
InvocationId uint64
|
InvocationId uint64
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalResponse, error) {
|
func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalResponse, error) {
|
||||||
@ -585,25 +594,20 @@ func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalRes
|
|||||||
// err = fmt.Errorf("operation %q is invalid", req.Op)
|
// err = fmt.Errorf("operation %q is invalid", req.Op)
|
||||||
// }
|
// }
|
||||||
|
|
||||||
switch req.What {
|
j.tasksMtx.Lock()
|
||||||
case "invocation":
|
var invocationId uint64
|
||||||
j.tasksMtx.Lock()
|
if j.activeInvocationId != 0 {
|
||||||
var invocationId uint64
|
invocationId = j.activeInvocationId
|
||||||
if j.activeInvocationId != 0 {
|
} else {
|
||||||
invocationId = j.activeInvocationId
|
invocationId = j.nextInvocationId
|
||||||
} else {
|
|
||||||
invocationId = j.nextInvocationId
|
|
||||||
}
|
|
||||||
// non-blocking send (.Run() must not hold mutex while waiting for signals)
|
|
||||||
select {
|
|
||||||
case j.trigger <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
j.tasksMtx.Unlock()
|
|
||||||
return &ActiveSideSignalResponse{InvocationId: invocationId}, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown signal %q", req.What)
|
|
||||||
}
|
}
|
||||||
|
// non-blocking send (.Run() must not hold mutex while waiting for signals)
|
||||||
|
select {
|
||||||
|
case j.trigger <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
j.tasksMtx.Unlock()
|
||||||
|
return &ActiveSideSignalResponse{InvocationId: invocationId}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *ActiveSide) do(ctx context.Context) {
|
func (j *ActiveSide) do(ctx context.Context) {
|
||||||
|
187
daemon/job/trigger/trigger.go
Normal file
187
daemon/job/trigger/trigger.go
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
//
|
||||||
|
//
|
||||||
|
// Alternative Design (in "RustGo")
|
||||||
|
//
|
||||||
|
// enum InternalMsg {
|
||||||
|
// Trigger((), chan (TriggerResponse, error)),
|
||||||
|
// Poll(PollRequest, chan PollResponse),
|
||||||
|
// Reset(ResetRequest, chan (ResetResponse, error)),
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// enum State {
|
||||||
|
// Running{
|
||||||
|
// invocationId: u32,
|
||||||
|
// cancelCurrentInvocation: context.CancelFunc
|
||||||
|
// }
|
||||||
|
// Waiting{
|
||||||
|
// nextInvocationId: u32,
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// for msg := <- t.internalMsgs {
|
||||||
|
// match (msg, state) {
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
package trigger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/zrepl/zrepl/daemon/logging"
|
||||||
|
"github.com/zrepl/zrepl/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type T struct {
|
||||||
|
mtx sync.Mutex
|
||||||
|
cv sync.Cond
|
||||||
|
|
||||||
|
nextInvocationId uint64
|
||||||
|
activeInvocationId uint64 // 0 <=> inactive
|
||||||
|
triggerPending bool
|
||||||
|
contextDone bool
|
||||||
|
reset chan uint64
|
||||||
|
stopWaitForReset chan struct{}
|
||||||
|
cancelCurrentInvocation context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func New() *T {
|
||||||
|
t := &T{
|
||||||
|
activeInvocationId: math.MaxUint64,
|
||||||
|
nextInvocationId: 1,
|
||||||
|
}
|
||||||
|
t.cv.L = &t.mtx
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *T) WaitForTrigger(ctx context.Context) (rctx context.Context, err error) {
|
||||||
|
t.mtx.Lock()
|
||||||
|
defer t.mtx.Unlock()
|
||||||
|
|
||||||
|
if t.activeInvocationId == 0 {
|
||||||
|
return nil, fmt.Errorf("must be running when calling this function")
|
||||||
|
}
|
||||||
|
t.activeInvocationId = 0
|
||||||
|
t.cancelCurrentInvocation = nil
|
||||||
|
|
||||||
|
if t.contextDone == true {
|
||||||
|
panic("implementation error: this variable is only true while in WaitForTrigger, and that's a mutually exclusive function")
|
||||||
|
}
|
||||||
|
stopWaitingForDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-stopWaitingForDone:
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.mtx.Lock()
|
||||||
|
t.contextDone = true
|
||||||
|
t.cv.Broadcast()
|
||||||
|
t.mtx.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
t.triggerPending = false
|
||||||
|
t.contextDone = false
|
||||||
|
}()
|
||||||
|
for !t.triggerPending && !t.contextDone {
|
||||||
|
t.cv.Wait()
|
||||||
|
}
|
||||||
|
close(stopWaitingForDone)
|
||||||
|
if t.contextDone {
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
panic("implementation error: contextDone <=> ctx.Err() != nil")
|
||||||
|
}
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
t.activeInvocationId = t.nextInvocationId
|
||||||
|
t.nextInvocationId++
|
||||||
|
rctx, t.cancelCurrentInvocation = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
return rctx, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TriggerResponse struct {
|
||||||
|
InvocationId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *T) Trigger() (TriggerResponse, error) {
|
||||||
|
t.mtx.Lock()
|
||||||
|
defer t.mtx.Unlock()
|
||||||
|
var invocationId uint64
|
||||||
|
if t.activeInvocationId != 0 {
|
||||||
|
invocationId = t.activeInvocationId
|
||||||
|
} else {
|
||||||
|
invocationId = t.nextInvocationId
|
||||||
|
}
|
||||||
|
// non-blocking send (.Run() must not hold mutex while waiting for signals)
|
||||||
|
t.triggerPending = true
|
||||||
|
t.cv.Broadcast()
|
||||||
|
return TriggerResponse{InvocationId: invocationId}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type PollRequest struct {
|
||||||
|
InvocationId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type PollResponse struct {
|
||||||
|
Done bool
|
||||||
|
InvocationId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *T) Poll(req PollRequest) (res PollResponse) {
|
||||||
|
t.mtx.Lock()
|
||||||
|
defer t.mtx.Unlock()
|
||||||
|
|
||||||
|
waitForId := req.InvocationId
|
||||||
|
if req.InvocationId == 0 {
|
||||||
|
// handle the case where the client doesn't know what the current invocation id is
|
||||||
|
if t.activeInvocationId != 0 {
|
||||||
|
waitForId = t.activeInvocationId
|
||||||
|
} else {
|
||||||
|
waitForId = t.nextInvocationId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var done bool
|
||||||
|
if t.activeInvocationId == 0 {
|
||||||
|
done = waitForId < t.nextInvocationId
|
||||||
|
} else {
|
||||||
|
done = waitForId < t.activeInvocationId
|
||||||
|
}
|
||||||
|
return PollResponse{Done: done, InvocationId: waitForId}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResetRequest struct {
|
||||||
|
InvocationId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResetResponse struct {
|
||||||
|
InvocationId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *T) Reset(req ResetRequest) (*ResetResponse, error) {
|
||||||
|
t.mtx.Lock()
|
||||||
|
defer t.mtx.Unlock()
|
||||||
|
|
||||||
|
resetId := req.InvocationId
|
||||||
|
if req.InvocationId == 0 {
|
||||||
|
// handle the case where the client doesn't know what the current invocation id is
|
||||||
|
resetId = t.activeInvocationId
|
||||||
|
}
|
||||||
|
|
||||||
|
if resetId == 0 {
|
||||||
|
return nil, fmt.Errorf("no active invocation")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resetId != t.activeInvocationId {
|
||||||
|
return nil, fmt.Errorf("active invocation (%d) is not the invocation requested for reset (%d); (active invocation '0' indicates no active invocation)", t.activeInvocationId, resetId)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.cancelCurrentInvocation()
|
||||||
|
|
||||||
|
return &ResetResponse{InvocationId: resetId}, nil
|
||||||
|
}
|
84
daemon/job/trigger/trigger_test.go
Normal file
84
daemon/job/trigger/trigger_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package trigger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBasics(t *testing.T) {
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
tr := New()
|
||||||
|
|
||||||
|
triggered := make(chan int)
|
||||||
|
waitForTriggerError := make(chan error)
|
||||||
|
waitForResetCallToBeMadeByMainGoroutine := make(chan struct{})
|
||||||
|
postResetAssertionsDone := make(chan struct{})
|
||||||
|
|
||||||
|
taskCtx := context.Background()
|
||||||
|
taskCtx, cancelTaskCtx := context.WithCancel(taskCtx)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
taskCtx := context.WithValue(taskCtx, "mykey", "myvalue")
|
||||||
|
|
||||||
|
triggers := 0
|
||||||
|
|
||||||
|
outer:
|
||||||
|
for {
|
||||||
|
invocationCtx, err := tr.WaitForTrigger(taskCtx)
|
||||||
|
if err != nil {
|
||||||
|
waitForTriggerError <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.Equal(t, invocationCtx.Value("mykey"), "myvalue")
|
||||||
|
|
||||||
|
triggers++
|
||||||
|
triggered <- triggers
|
||||||
|
|
||||||
|
switch triggers {
|
||||||
|
case 1:
|
||||||
|
continue outer
|
||||||
|
case 2:
|
||||||
|
<-waitForResetCallToBeMadeByMainGoroutine
|
||||||
|
require.Equal(t, context.Canceled, invocationCtx.Err(), "Reset() cancels invocation context")
|
||||||
|
require.Nil(t, taskCtx.Err(), "Reset() does not cancel task context")
|
||||||
|
close(postResetAssertionsDone)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Logf("trigger 1")
|
||||||
|
_, err := tr.Trigger()
|
||||||
|
require.NoError(t, err)
|
||||||
|
v := <-triggered
|
||||||
|
require.Equal(t, 1, v)
|
||||||
|
|
||||||
|
t.Logf("trigger 2")
|
||||||
|
triggerResponse, err := tr.Trigger()
|
||||||
|
require.NoError(t, err)
|
||||||
|
v = <-triggered
|
||||||
|
require.Equal(t, 2, v)
|
||||||
|
|
||||||
|
t.Logf("reset")
|
||||||
|
resetResponse, err := tr.Reset(ResetRequest{InvocationId: triggerResponse.InvocationId})
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("reset response: %#v", resetResponse)
|
||||||
|
close(waitForResetCallToBeMadeByMainGoroutine)
|
||||||
|
<-postResetAssertionsDone
|
||||||
|
|
||||||
|
t.Logf("cancel the context")
|
||||||
|
cancelTaskCtx()
|
||||||
|
wfte := <-waitForTriggerError
|
||||||
|
require.Equal(t, taskCtx.Err(), wfte)
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user