diff --git a/daemon/daemon.go b/daemon/daemon.go index 8ada3c9..41cf667 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -7,6 +7,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/daemon/job" + "github.com/zrepl/zrepl/daemon/job/wakeup" "github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/logger" "github.com/zrepl/zrepl/version" @@ -100,13 +101,13 @@ type jobs struct { // m protects all fields below it m sync.RWMutex - wakeups map[string]job.WakeupFunc // by Job.Name + wakeups map[string]wakeup.Func // by Job.Name jobs map[string]job.Job } func newJobs() *jobs { return &jobs{ - wakeups: make(map[string]job.WakeupFunc), + wakeups: make(map[string]wakeup.Func), jobs: make(map[string]job.Job), } } @@ -193,8 +194,8 @@ func (s *jobs) start(ctx context.Context, j job.Job, internal bool) { s.jobs[jobName] = j ctx = job.WithLogger(ctx, jobLog) - ctx, wakeupChan := job.WithWakeup(ctx) - s.wakeups[jobName] = wakeupChan + ctx, wakeup := wakeup.Context(ctx) + s.wakeups[jobName] = wakeup s.wg.Add(1) go func() { diff --git a/daemon/job/active.go b/daemon/job/active.go index 791530e..b50216d 100644 --- a/daemon/job/active.go +++ b/daemon/job/active.go @@ -6,6 +6,7 @@ import ( "github.com/problame/go-streamrpc" "github.com/prometheus/client_golang/prometheus" "github.com/zrepl/zrepl/config" + "github.com/zrepl/zrepl/daemon/job/wakeup" "github.com/zrepl/zrepl/daemon/transport/connecter" "github.com/zrepl/zrepl/daemon/filters" "github.com/zrepl/zrepl/daemon/pruner" @@ -233,7 +234,7 @@ outer: log.WithError(ctx.Err()).Info("context") break outer - case <-WaitWakeup(ctx): + case <-wakeup.Wait(ctx): case <-periodicDone: } invocationCount++ diff --git a/daemon/job/job.go b/daemon/job/job.go index 85334a5..5b3a684 100644 --- a/daemon/job/job.go +++ b/daemon/job/job.go @@ -3,7 +3,6 @@ package job import ( "context" "encoding/json" - "errors" "fmt" "github.com/prometheus/client_golang/prometheus" "github.com/zrepl/zrepl/logger" @@ -15,7 +14,6 @@ type contextKey int const ( contextKeyLog contextKey = iota - contextKeyWakeup ) func GetLogger(ctx context.Context) Logger { @@ -29,22 +27,6 @@ func WithLogger(ctx context.Context, l Logger) context.Context { return context.WithValue(ctx, contextKeyLog, l) } -type WakeupFunc func() error - -var AlreadyWokenUp = errors.New("already woken up") - -func WithWakeup(ctx context.Context) (context.Context, WakeupFunc) { - wc := make(chan struct{}) - wuf := func() error { - select { - case wc <- struct{}{}: - return nil - default: - return AlreadyWokenUp - } - } - return context.WithValue(ctx, contextKeyWakeup, wc), wuf -} type Job interface { Name() string @@ -119,12 +101,3 @@ func (s *Status) UnmarshalJSON(in []byte) (err error) { } return err } - -func WaitWakeup(ctx context.Context) <-chan struct{} { - wc, ok := ctx.Value(contextKeyWakeup).(chan struct{}) - if !ok { - wc = make(chan struct{}) - } - return wc -} - diff --git a/daemon/job/wakeup/wakeup.go b/daemon/job/wakeup/wakeup.go new file mode 100644 index 0000000..a099b53 --- /dev/null +++ b/daemon/job/wakeup/wakeup.go @@ -0,0 +1,35 @@ +package wakeup + +import ( + "context" + "errors" +) + +type contextKey int + +const contextKeyWakeup contextKey = iota + +func Wait(ctx context.Context) <-chan struct{} { + wc, ok := ctx.Value(contextKeyWakeup).(chan struct{}) + if !ok { + wc = make(chan struct{}) + } + return wc +} + +type Func func() error + +var AlreadyWokenUp = errors.New("already woken up") + +func Context(ctx context.Context) (context.Context, Func) { + wc := make(chan struct{}) + wuf := func() error { + select { + case wc <- struct{}{}: + return nil + default: + return AlreadyWokenUp + } + } + return context.WithValue(ctx, contextKeyWakeup, wc), wuf +}