WIP: extract logic for WaitForTrigger, Trigger, Reset, Poll into a reusable abstraction

This commit is contained in:
Christian Schwarz 2021-03-24 00:02:10 +01:00
parent 3e6cae1c8f
commit f28676b8d7
3 changed files with 336 additions and 61 deletions

View File

@ -440,58 +440,67 @@ func (j *ActiveSide) Run(ctx context.Context) {
j.reset = make(chan uint64)
j.nextInvocationId = 1
outer:
for {
log.Info("wait for replications")
select {
case <-ctx.Done():
log.WithError(ctx.Err()).Info("context")
break outer
type WaitTriggerResult interface {
case <-j.trigger:
j.mode.ResetConnectBackoff()
case <-periodicDone:
}
var t interface{
WaitForTrigger(context.Context) (context.Context, <-chan WaitTriggerResult)
}
j.tasksMtx.Lock()
j.activeInvocationId = j.nextInvocationId
j.nextInvocationId++
thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below
j.tasksMtx.Unlock()
for {
log.Info("wait for triggers")
// j.tasksMtx.Lock()
// j.activeInvocationId = j.nextInvocationId
// j.nextInvocationId++
// thisInvocation := j.activeInvocationId // stack-local, for use in reset-handler goroutine below
// j.tasksMtx.Unlock()
// // setup the goroutine that waits for task resets
// // Task resets are converted into cancellations of the invocation context.
invocationCtx, cancelInvocation := context.WithCancel(invocationCtx)
// waitForResetCtx, stopWaitForReset := context.WithCancel(ctx)
// var wg sync.WaitGroup
// wg.Add(1)
// go func() {
// defer wg.Done()
// select {
// case <-waitForResetCtx.Done():
// return
// case reqResetInvocation := <-j.reset:
// l := log.WithField("requested_invocation_id", reqResetInvocation).
// WithField("this_invocation_id", thisInvocation)
// if reqResetInvocation == thisInvocation {
// l.Info("reset received, cancelling current invocation")
// cancelInvocation()
// } else {
// l.Debug("received reset for invocation id that is not us, discarding request")
// }
// }
// }()
// j.tasksMtx.Lock()
// j.activeInvocationId = 0
// j.tasksMtx.Unlock()
invocationCtx, err := t.WaitForTrigger(ctx)
if err != nil {
log.WithError(ctx.Err()).Info("error waiting for trigger")
break
}
j.mode.ResetConnectBackoff()
// setup the invocation context
invocationCtx, endSpan := trace.WithSpan(ctx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
invocationCtx, cancelInvocation := context.WithCancel(invocationCtx)
invocationCtx, endSpan := trace.WithSpan(invocationCtx, fmt.Sprintf("invocation-%d", j.nextInvocationId))
// setup the goroutine that waits for task resets
// Task resets are converted into cancellations of the invocation context.
waitForResetCtx, stopWaitForReset := context.WithCancel(ctx)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-waitForResetCtx.Done():
return
case reqResetInvocation := <-j.reset:
l := log.WithField("requested_invocation_id", reqResetInvocation).
WithField("this_invocation_id", thisInvocation)
if reqResetInvocation == thisInvocation {
l.Info("reset received, cancelling current invocation")
cancelInvocation()
} else {
l.Debug("received reset for invocation id that is not us, discarding request")
}
}
}()
j.do(invocationCtx)
stopWaitForReset()
wg.Wait()
j.tasksMtx.Lock()
j.activeInvocationId = 0
j.tasksMtx.Unlock()
endSpan()
}
@ -566,11 +575,11 @@ func (j *ActiveSide) Reset(req ActiveSideResetRequest) (*ActiveSideResetResponse
}
type ActiveSideTriggerRequest struct {
What string
}
type ActiveSideSignalResponse struct {
InvocationId uint64
}
func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalResponse, error) {
@ -585,8 +594,6 @@ func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalRes
// err = fmt.Errorf("operation %q is invalid", req.Op)
// }
switch req.What {
case "invocation":
j.tasksMtx.Lock()
var invocationId uint64
if j.activeInvocationId != 0 {
@ -601,9 +608,6 @@ func (j *ActiveSide) Trigger(req ActiveSideTriggerRequest) (*ActiveSideSignalRes
}
j.tasksMtx.Unlock()
return &ActiveSideSignalResponse{InvocationId: invocationId}, nil
default:
return nil, fmt.Errorf("unknown signal %q", req.What)
}
}
func (j *ActiveSide) do(ctx context.Context) {

View File

@ -0,0 +1,187 @@
//
//
// Alternative Design (in "RustGo")
//
// enum InternalMsg {
// Trigger((), chan (TriggerResponse, error)),
// Poll(PollRequest, chan PollResponse),
// Reset(ResetRequest, chan (ResetResponse, error)),
// }
//
// enum State {
// Running{
// invocationId: u32,
// cancelCurrentInvocation: context.CancelFunc
// }
// Waiting{
// nextInvocationId: u32,
// }
// }
//
// for msg := <- t.internalMsgs {
// match (msg, state) {
// ...
// }
// }
package trigger
import (
"context"
"fmt"
"math"
"sync"
"github.com/zrepl/zrepl/daemon/logging"
"github.com/zrepl/zrepl/logger"
)
type T struct {
mtx sync.Mutex
cv sync.Cond
nextInvocationId uint64
activeInvocationId uint64 // 0 <=> inactive
triggerPending bool
contextDone bool
reset chan uint64
stopWaitForReset chan struct{}
cancelCurrentInvocation context.CancelFunc
}
func New() *T {
t := &T{
activeInvocationId: math.MaxUint64,
nextInvocationId: 1,
}
t.cv.L = &t.mtx
return t
}
func (t *T) WaitForTrigger(ctx context.Context) (rctx context.Context, err error) {
t.mtx.Lock()
defer t.mtx.Unlock()
if t.activeInvocationId == 0 {
return nil, fmt.Errorf("must be running when calling this function")
}
t.activeInvocationId = 0
t.cancelCurrentInvocation = nil
if t.contextDone == true {
panic("implementation error: this variable is only true while in WaitForTrigger, and that's a mutually exclusive function")
}
stopWaitingForDone := make(chan struct{})
go func() {
select {
case <-stopWaitingForDone:
case <-ctx.Done():
t.mtx.Lock()
t.contextDone = true
t.cv.Broadcast()
t.mtx.Unlock()
}
}()
defer func() {
t.triggerPending = false
t.contextDone = false
}()
for !t.triggerPending && !t.contextDone {
t.cv.Wait()
}
close(stopWaitingForDone)
if t.contextDone {
if ctx.Err() == nil {
panic("implementation error: contextDone <=> ctx.Err() != nil")
}
return nil, ctx.Err()
}
t.activeInvocationId = t.nextInvocationId
t.nextInvocationId++
rctx, t.cancelCurrentInvocation = context.WithCancel(ctx)
return rctx, nil
}
type TriggerResponse struct {
InvocationId uint64
}
func (t *T) Trigger() (TriggerResponse, error) {
t.mtx.Lock()
defer t.mtx.Unlock()
var invocationId uint64
if t.activeInvocationId != 0 {
invocationId = t.activeInvocationId
} else {
invocationId = t.nextInvocationId
}
// non-blocking send (.Run() must not hold mutex while waiting for signals)
t.triggerPending = true
t.cv.Broadcast()
return TriggerResponse{InvocationId: invocationId}, nil
}
type PollRequest struct {
InvocationId uint64
}
type PollResponse struct {
Done bool
InvocationId uint64
}
func (t *T) Poll(req PollRequest) (res PollResponse) {
t.mtx.Lock()
defer t.mtx.Unlock()
waitForId := req.InvocationId
if req.InvocationId == 0 {
// handle the case where the client doesn't know what the current invocation id is
if t.activeInvocationId != 0 {
waitForId = t.activeInvocationId
} else {
waitForId = t.nextInvocationId
}
}
var done bool
if t.activeInvocationId == 0 {
done = waitForId < t.nextInvocationId
} else {
done = waitForId < t.activeInvocationId
}
return PollResponse{Done: done, InvocationId: waitForId}
}
type ResetRequest struct {
InvocationId uint64
}
type ResetResponse struct {
InvocationId uint64
}
func (t *T) Reset(req ResetRequest) (*ResetResponse, error) {
t.mtx.Lock()
defer t.mtx.Unlock()
resetId := req.InvocationId
if req.InvocationId == 0 {
// handle the case where the client doesn't know what the current invocation id is
resetId = t.activeInvocationId
}
if resetId == 0 {
return nil, fmt.Errorf("no active invocation")
}
if resetId != t.activeInvocationId {
return nil, fmt.Errorf("active invocation (%d) is not the invocation requested for reset (%d); (active invocation '0' indicates no active invocation)", t.activeInvocationId, resetId)
}
t.cancelCurrentInvocation()
return &ResetResponse{InvocationId: resetId}, nil
}

View File

@ -0,0 +1,84 @@
package trigger
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/require"
)
func TestBasics(t *testing.T) {
var wg sync.WaitGroup
defer wg.Wait()
tr := New()
triggered := make(chan int)
waitForTriggerError := make(chan error)
waitForResetCallToBeMadeByMainGoroutine := make(chan struct{})
postResetAssertionsDone := make(chan struct{})
taskCtx := context.Background()
taskCtx, cancelTaskCtx := context.WithCancel(taskCtx)
wg.Add(1)
go func() {
defer wg.Done()
taskCtx := context.WithValue(taskCtx, "mykey", "myvalue")
triggers := 0
outer:
for {
invocationCtx, err := tr.WaitForTrigger(taskCtx)
if err != nil {
waitForTriggerError <- err
return
}
require.Equal(t, invocationCtx.Value("mykey"), "myvalue")
triggers++
triggered <- triggers
switch triggers {
case 1:
continue outer
case 2:
<-waitForResetCallToBeMadeByMainGoroutine
require.Equal(t, context.Canceled, invocationCtx.Err(), "Reset() cancels invocation context")
require.Nil(t, taskCtx.Err(), "Reset() does not cancel task context")
close(postResetAssertionsDone)
}
}
}()
t.Logf("trigger 1")
_, err := tr.Trigger()
require.NoError(t, err)
v := <-triggered
require.Equal(t, 1, v)
t.Logf("trigger 2")
triggerResponse, err := tr.Trigger()
require.NoError(t, err)
v = <-triggered
require.Equal(t, 2, v)
t.Logf("reset")
resetResponse, err := tr.Reset(ResetRequest{InvocationId: triggerResponse.InvocationId})
require.NoError(t, err)
t.Logf("reset response: %#v", resetResponse)
close(waitForResetCallToBeMadeByMainGoroutine)
<-postResetAssertionsDone
t.Logf("cancel the context")
cancelTaskCtx()
wfte := <-waitForTriggerError
require.Equal(t, taskCtx.Err(), wfte)
}