zrepl/util/optionaldeadline/optionaldeadline.go
2019-03-27 13:12:26 +01:00

84 lines
1.7 KiB
Go

package optionaldeadline
import (
"context"
"sync"
"time"
)
// contextWithOptionalDeadline is a context.Context that inherits values and
// cancellation from its parent (the embedded Context) and can be given a
// deadline once, at some point after its creation.
//
// deadline, enforced and err are guarded by m. done is created once and is
// closed exactly once, while m is held and right after err has been set, so
// Err is non-nil if and only if Done is closed.
type contextWithOptionalDeadline struct {
	context.Context

	m        sync.Mutex
	deadline time.Time     // deadline passed to enforceDeadline; zero until then
	enforced bool          // whether enforceDeadline has been called
	done     chan struct{} // closed on parent cancellation or deadline expiry
	err      error         // nil until done is closed
}

// Deadline reports the earlier of the parent's deadline and the deadline
// passed to enforceDeadline. ok is false if neither is set.
func (c *contextWithOptionalDeadline) Deadline() (deadline time.Time, ok bool) {
	c.m.Lock()
	own := c.deadline
	c.m.Unlock()
	ownOK := !own.IsZero()

	// Work is cancelled by whichever deadline comes first, so report that one.
	parent, parentOK := c.Context.Deadline()
	if parentOK && (!ownOK || parent.Before(own)) {
		return parent, true
	}
	return own, ownOK
}

// Err returns nil while Done is open and the cause of cancellation (the
// parent's Err or context.DeadlineExceeded) once it is closed.
func (c *contextWithOptionalDeadline) Err() error {
	c.m.Lock()
	defer c.m.Unlock()
	return c.err
}

// Done returns a channel that is closed when the parent is cancelled or the
// enforced deadline expires, whichever happens first.
func (c *contextWithOptionalDeadline) Done() <-chan struct{} {
	return c.done
}

// cancel records err and closes done unless that has already happened.
// Both happen under the same lock so that Err never reports a non-nil error
// while Done is still open, and so that done is never closed twice.
func (c *contextWithOptionalDeadline) cancel(err error) {
	c.m.Lock()
	defer c.m.Unlock()
	select {
	case <-c.done:
		return // already cancelled; first cause wins
	default:
	}
	c.err = err
	close(c.done)
}

// ContextWithOptionalDeadline returns a child of pctx together with a function
// that imposes a deadline on that child after the fact.
//
// Until enforceDeadline is called, ctx behaves like pctx. It carries pctx's
// values and is cancelled (with pctx.Err()) when pctx is cancelled. The first
// call to enforceDeadline also cancels ctx with context.DeadlineExceeded once
// the given deadline has passed, or immediately if it is not in the future.
// Later calls are ignored.
//
// Helper goroutines exit as soon as ctx is done. In one case a watcher
// goroutine lives on, as with context.WithDeadline when its CancelFunc is
// never called: pctx is cancellable but never cancelled, and the enforced
// deadline never passes.
func ContextWithOptionalDeadline(pctx context.Context) (ctx context.Context, enforceDeadline func(deadline time.Time)) {
	rctx := &contextWithOptionalDeadline{
		Context: pctx,
		done:    make(chan struct{}),
	}

	// Propagate parent cancellation even if enforceDeadline is never called.
	// A nil Done channel (e.g. context.Background) can never fire, so no
	// watcher is started in that case.
	if pdone := pctx.Done(); pdone != nil {
		go func() {
			select {
			case <-pdone:
				rctx.cancel(pctx.Err())
			case <-rctx.done:
			}
		}()
	}

	enforceDeadline = func(deadline time.Time) {
		// Only the first call has an effect. A dedicated flag is used rather
		// than deadline.IsZero() so that a zero deadline also counts as a call.
		rctx.m.Lock()
		alreadyCalled := rctx.enforced
		if !alreadyCalled {
			rctx.enforced = true
			rctx.deadline = deadline
		}
		rctx.m.Unlock()
		if alreadyCalled {
			return
		}

		sleepTime := time.Until(deadline)
		if sleepTime <= 0 {
			rctx.cancel(context.DeadlineExceeded)
			return
		}
		go func() {
			timer := time.NewTimer(sleepTime)
			defer timer.Stop()
			select {
			case <-timer.C:
				rctx.cancel(context.DeadlineExceeded)
			case <-rctx.done: // parent was cancelled first
			}
		}()
	}
	return rctx, enforceDeadline
}