WIP
This commit is contained in: parent 1a8d2c5ebe → commit 1826535e6f
@@ -146,11 +146,11 @@ outer:

 	j.mainTask.Log().Debug("replicating from lhs to rhs")
 	j.mainTask.Enter("replicate")

+	replication.Replicate(
+		ctx,
+		replication.NewEndpointPairPull(sender, receiver),
+		replication.NewIncrementalPathReplicator(),
+		nil, //FIXME
+	)

 	j.mainTask.Finish()
@@ -10,6 +10,8 @@ import (
 	"github.com/zrepl/zrepl/util"
 	"github.com/zrepl/zrepl/cmd/replication"
 	"github.com/problame/go-streamrpc"
+	"io"
+	"net"
 )

 type PullJob struct {
@@ -107,7 +109,15 @@ func (j *PullJob) JobStart(ctx context.Context) {

 	ticker := time.NewTicker(j.Interval)
 	for {
+		begin := time.Now()
 		j.doRun(ctx)
+		duration := time.Now().Sub(begin)
+		if duration > j.Interval {
+			j.task.Log().
+				WithField("actual_duration", duration).
+				WithField("configured_interval", j.Interval).
+				Warn("pull run took longer than configured interval")
+		}
 		select {
 		case <-ctx.Done():
 			j.task.Log().WithError(ctx.Err()).Info("context")
@@ -124,32 +134,86 @@ var STREAMRPC_CONFIG = &streamrpc.ConnConfig{ // FIXME oversight and configurabi
 	TxChunkSize: 4096 * 4096,
 }

+type streamrpcRWCToNetConnAdatper struct {
+	io.ReadWriteCloser
+}
+
+func (streamrpcRWCToNetConnAdatper) LocalAddr() net.Addr {
+	panic("implement me")
+}
+
+func (streamrpcRWCToNetConnAdatper) RemoteAddr() net.Addr {
+	panic("implement me")
+}
+
+func (streamrpcRWCToNetConnAdatper) SetDeadline(t time.Time) error {
+	panic("implement me")
+}
+
+func (streamrpcRWCToNetConnAdatper) SetReadDeadline(t time.Time) error {
+	panic("implement me")
+}
+
+func (streamrpcRWCToNetConnAdatper) SetWriteDeadline(t time.Time) error {
+	panic("implement me")
+}
+
+type streamrpcRWCConnecterToNetConnAdapter struct {
+	RWCConnecter
+	ReadDump, WriteDump string
+}
+
+func (s streamrpcRWCConnecterToNetConnAdapter) Connect(ctx context.Context) (net.Conn, error) {
+	rwc, err := s.RWCConnecter.Connect()
+	if err != nil {
+		return nil, err
+	}
+	rwc, err = util.NewReadWriteCloserLogger(rwc, s.ReadDump, s.WriteDump)
+	if err != nil {
+		rwc.Close()
+		return nil, err
+	}
+	return streamrpcRWCToNetConnAdatper{rwc}, nil
+}
+
+type tcpConnecter struct {
+	d net.Dialer
+}
+
+func (t *tcpConnecter) Connect(ctx context.Context) (net.Conn, error) {
+	return t.d.DialContext(ctx, "tcp", "192.168.122.128:8888")
+}

 func (j *PullJob) doRun(ctx context.Context) {

 	j.task.Enter("run")
 	defer j.task.Finish()

 	j.task.Log().Info("connecting")
-	rwc, err := j.Connect.Connect()
-	if err != nil {
-		j.task.Log().WithError(err).Error("error connecting")
-		return
-	}
+	//connecter := streamrpcRWCConnecterToNetConnAdapter{
+	//	RWCConnecter: j.Connect,
+	//	ReadDump:  j.Debug.Conn.ReadDump,
+	//	WriteDump: j.Debug.Conn.WriteDump,
+	//}

+	// FIXME
+	connecter := &tcpConnecter{net.Dialer{
+		Timeout: 2 * time.Second,
+	}}

-	rwc, err = util.NewReadWriteCloserLogger(rwc, j.Debug.Conn.ReadDump, j.Debug.Conn.WriteDump)
-	if err != nil {
-		return
-	}
-
-	client := RemoteEndpoint{streamrpc.NewClientOnConn(rwc, STREAMRPC_CONFIG)}
-	if j.Debug.RPC.Log {
-		// FIXME implement support
-		// client.SetLogger(j.task.Log(), true)
-	}
+	clientConf := &streamrpc.ClientConfig{
+		MaxConnectAttempts:     5, // FIXME
+		ReconnectBackoffBase:   1 * time.Second,
+		ReconnectBackoffFactor: 2,
+		ConnConfig:             STREAMRPC_CONFIG,
+	}
+
+	client, err := streamrpc.NewClient(connecter, clientConf)
+	defer client.Close()

 	j.task.Enter("pull")

+	sender := RemoteEndpoint{client}

 	puller, err := NewReceiverEndpoint(
 		j.Mapping,
 		NewPrefixFilter(j.SnapshotPrefix),
@@ -161,10 +225,27 @@ func (j *PullJob) doRun(ctx context.Context) {
 	}

 	replicator := replication.NewIncrementalPathReplicator()
-	replication.Replicate(context.WithValue(ctx, replication.ContextKeyLog, j.task.Log()), replication.NewEndpointPairPull(client, puller), replicator)
+	ctx = context.WithValue(ctx, replication.ContextKeyLog, j.task.Log())
+	ctx = context.WithValue(ctx, streamrpc.ContextKeyLogger, j.task.Log())
+	ctx, enforceDeadline := util.ContextWithOptionalDeadline(ctx)

-	closeRPCWithTimeout(j.task, client, time.Second*1, "")
-	rwc.Close()
+	// Try replicating each file system regardless of j.Interval
+	// (this does not solve the underlying problem that j.Interval is too short,
+	// but it covers the case of initial replication taking longer than all
+	// incremental replications afterwards)
+	allTriedOnce := make(chan struct{})
+	replicationBegin := time.Now()
+	go func() {
+		select {
+		case <-allTriedOnce:
+			enforceDeadline(replicationBegin.Add(j.Interval))
+		case <-ctx.Done():
+		}
+	}()
+	replication.Replicate(ctx, replication.NewEndpointPairPull(sender, puller), replicator, allTriedOnce)

+	client.Close()
 	j.task.Finish()

 	j.task.Enter("prune")
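Editor's note: the hunk above wires util.ContextWithOptionalDeadline (added later in this commit) to the allTriedOnce channel, so j.Interval only becomes a hard deadline once every filesystem has been tried at least once. A minimal, self-contained sketch of that pattern; the 1-second interval is a placeholder for j.Interval and close(allTriedOnce) stands in for Replicate signalling its first full pass:

	package main

	import (
		"context"
		"fmt"
		"time"

		"github.com/zrepl/zrepl/util"
	)

	func main() {
		ctx, enforceDeadline := util.ContextWithOptionalDeadline(context.Background())

		allTriedOnce := make(chan struct{})
		begin := time.Now()
		go func() {
			select {
			case <-allTriedOnce:
				// Only after the first full pass does the interval
				// become a hard deadline for the remaining retries.
				enforceDeadline(begin.Add(1 * time.Second)) // placeholder for j.Interval
			case <-ctx.Done():
			}
		}()

		close(allTriedOnce) // stand-in for Replicate's first-pass signal
		<-ctx.Done()
		fmt.Println("stopped:", ctx.Err()) // context.DeadlineExceeded after ~1s
	}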
@@ -200,7 +281,8 @@ func closeRPCWithTimeout(task *Task, remote RemoteEndpoint, timeout time.Duratio

 	ch := make(chan error)
 	go func() {
-		ch <- remote.Close()
+		remote.Close()
+		ch <- nil
 		close(ch)
 	}()
@@ -9,6 +9,7 @@ import (
 	"github.com/pkg/errors"
 	"github.com/zrepl/zrepl/util"
 	"github.com/problame/go-streamrpc"
+	"net"
 )

 type SourceJob struct {
@@ -138,7 +139,9 @@ func (j *SourceJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Pru

 func (j *SourceJob) serve(ctx context.Context, task *Task) {

-	listener, err := j.Serve.Listen()
+	//listener, err := j.Serve.Listen()
+	// FIXME
+	listener, err := net.Listen("tcp", "192.168.122.128:8888")
 	if err != nil {
 		task.Log().WithError(err).Error("error listening")
 		return
@@ -208,7 +211,7 @@ func (j *SourceJob) handleConnection(rwc io.ReadWriteCloser, task *Task) {

 	senderEP := NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix))

-	handler := HandlerAdaptor{senderEP}
+	handler := HandlerAdaptor{senderEP, task.Log()}
 	// FIXME logging support or erase config
 	//if j.Debug.RPC.Log {
 	//	rpclog := task.Log().WithField("subsystem", "rpc")
@@ -217,35 +220,8 @@ func (j *SourceJob) handleConnection(rwc io.ReadWriteCloser, task *Task) {

 	if err := streamrpc.ServeConn(rwc, STREAMRPC_CONFIG, handler.Handle); err != nil {
 		task.Log().WithError(err).Error("error serving connection")
 	} else {
 		task.Log().Info("client closed connection")
 	}
-
-	// wait for client to close connection
-	// FIXME: we cannot just close it like we would to with a TCP socket because
-	// FIXME: go-nettsh's Close() may overtake the remaining data in the pipe
-	const CLIENT_HANGUP_TIMEOUT = 1 * time.Second
-	task.Log().
-		WithField("timeout", CLIENT_HANGUP_TIMEOUT).
-		Debug("waiting for client to hang up")
-
-	wchan := make(chan error)
-	go func() {
-		var pseudo [1]byte
-		_, err := io.ReadFull(rwc, pseudo[:])
-		wchan <- err
-	}()
-	var werr error
-	select {
-	case werr = <-wchan:
-		// all right
-	case <-time.After(CLIENT_HANGUP_TIMEOUT):
-		werr = errors.New("client did not close connection within timeout")
-	}
-	if werr != nil && werr != io.EOF {
-		task.Log().WithError(werr).
-			Error("error waiting for client to hang up")
-	}
-	task.Log().Info("closing client connection")
-	if err = rwc.Close(); err != nil {
-		task.Log().WithError(err).Error("error force-closing connection")
-	}
 }
@@ -3,6 +3,9 @@ package replication

 import (
 	"context"
 	"io"
+	"container/list"
+	"fmt"
+	"net"
 )

 type ReplicationEndpoint interface {
@@ -77,7 +80,87 @@ type Logger interface{
 	Printf(fmt string, args ...interface{})
 }

-func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicator) {
+type replicationWork struct {
+	fs *Filesystem
+}
+
+type FilesystemReplicationResult struct {
+	Done      bool
+	Retry     bool
+	Unfixable bool
+}
+
+func handleGenericEndpointError(err error) FilesystemReplicationResult {
+	if _, ok := err.(net.Error); ok {
+		return FilesystemReplicationResult{Retry: true}
+	}
+	return FilesystemReplicationResult{Unfixable: true}
+}
+
+func driveFSReplication(ctx context.Context, ws *list.List, allTriedOnce chan struct{}, log Logger, f func(w *replicationWork) FilesystemReplicationResult) {
+	initialLen, fCalls := ws.Len(), 0
+	for ws.Len() > 0 {
+		select {
+		case <-ctx.Done():
+			log.Printf("aborting replication due to context error: %s", ctx.Err())
+			return
+		default:
+		}
+
+		w := ws.Remove(ws.Front()).(*replicationWork)
+		res := f(w)
+		fCalls++
+		if fCalls == initialLen {
+			select {
+			case allTriedOnce <- struct{}{}:
+			default:
+			}
+		}
+		if res.Done {
+			log.Printf("finished replication of %s", w.fs.Path)
+			continue
+		}
+
+		if res.Unfixable {
+			log.Printf("aborting replication of %s after unfixable error", w.fs.Path)
+			continue
+		}
+
+		log.Printf("queuing replication of %s for retry", w.fs.Path)
+		ws.PushBack(w)
+	}
+}
+
+func resolveConflict(conflict error) (path []*FilesystemVersion, msg string) {
+	if noCommonAncestor, ok := conflict.(*ConflictNoCommonAncestor); ok {
+		if len(noCommonAncestor.SortedReceiverVersions) == 0 {
+			// FIXME hard-coded replication policy: most recent
+			// snapshot as source
+			var mostRecentSnap *FilesystemVersion
+			for n := len(noCommonAncestor.SortedSenderVersions) - 1; n >= 0; n-- {
+				if noCommonAncestor.SortedSenderVersions[n].Type == FilesystemVersion_Snapshot {
+					mostRecentSnap = noCommonAncestor.SortedSenderVersions[n]
+					break
+				}
+			}
+			if mostRecentSnap == nil {
+				return nil, "no snapshots available on sender side"
+			}
+			return []*FilesystemVersion{mostRecentSnap}, fmt.Sprintf("start replication at most recent snapshot %s", mostRecentSnap)
+		}
+	}
+	return nil, "no automated way to handle conflict type"
+}
+
+// Replicate replicates filesystems from ep.Sender() to ep.Receiver().
+//
+// All filesystems presented by the sending side are replicated,
+// unless the receiver rejects a Receive request with a *FilteredError.
+//
+// If an error occurs when replicating a filesystem, that error is logged to the logger in ctx.
+// Replicate continues with the replication of the remaining file systems.
+// Depending on the type of error, failed replications are retried in an unspecified order (currently FIFO).
+func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicator, allTriedOnce chan struct{}) {

 	log := ctx.Value(ContextKeyLog).(Logger)
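Editor's note: driveFSReplication (above) turns each FilesystemReplicationResult into queue behavior: Done and Unfixable drop the item, everything else is re-queued FIFO for retry. A self-contained toy of just that dispatch, with simplified stand-in types rather than the real package:

	package main

	import (
		"container/list"
		"fmt"
	)

	type result struct{ Done, Unfixable bool }

	// drive mimics driveFSReplication's queue discipline: done/unfixable
	// items leave the queue, everything else is retried FIFO.
	func drive(ws *list.List, f func(item string) result) {
		for ws.Len() > 0 {
			w := ws.Remove(ws.Front()).(string)
			switch res := f(w); {
			case res.Done:
				fmt.Println("finished", w)
			case res.Unfixable:
				fmt.Println("aborted", w)
			default:
				fmt.Println("requeue", w)
				ws.PushBack(w)
			}
		}
	}

	func main() {
		ws := list.New()
		ws.PushBack("zroot/a")
		ws.PushBack("zroot/b")
		tried := map[string]int{}
		drive(ws, func(item string) result {
			tried[item]++
			if item == "zroot/a" && tried[item] == 1 {
				return result{} // transient: re-queued, retried after zroot/b
			}
			return result{Done: true}
		})
	}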
@@ -93,18 +176,27 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
 		return
 	}

+	wq := list.New()
 	for _, fs := range sfss {
+		wq.PushBack(&replicationWork{
+			fs: fs,
+		})
+	}
+
+	driveFSReplication(ctx, wq, allTriedOnce, log, func(w *replicationWork) FilesystemReplicationResult {
+		fs := w.fs

 		log.Printf("replicating %s", fs.Path)

 		sfsvs, err := ep.Sender().ListFilesystemVersions(ctx, fs.Path)
 		if err != nil {
-			log.Printf("sender error %s", err)
-			continue
+			log.Printf("cannot get remote filesystem versions: %s", err)
+			return handleGenericEndpointError(err)
 		}

 		if len(sfsvs) <= 1 {
 			log.Printf("sender does not have any versions")
-			continue
+			return FilesystemReplicationResult{Unfixable: true}
 		}

 		receiverFSExists := false
@@ -118,47 +210,35 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
 		if receiverFSExists {
 			rfsvs, err = ep.Receiver().ListFilesystemVersions(ctx, fs.Path)
 			if err != nil {
-				log.Printf("receiver error %s", err)
 				if _, ok := err.(FilteredError); ok {
-					// Remote does not map filesystem, don't try to tx it
-					continue
+					log.Printf("receiver does not map %s", fs.Path)
+					return FilesystemReplicationResult{Done: true}
 				}
-				// log and ignore
-				continue
+				log.Printf("receiver error %s", err)
+				return handleGenericEndpointError(err)
 			}
 		} else {
 			rfsvs = []*FilesystemVersion{}
 		}

 		path, conflict := IncrementalPath(rfsvs, sfsvs)
-		if noCommonAncestor, ok := conflict.(*ConflictNoCommonAncestor); ok {
-			if len(noCommonAncestor.SortedReceiverVersions) == 0 {
-				log.Printf("initial replication")
-				// FIXME hard-coded replication policy: most recent
-				// snapshot as source
-				var mostRecentSnap *FilesystemVersion
-				for n := len(sfsvs) - 1; n >= 0; n-- {
-					if sfsvs[n].Type == FilesystemVersion_Snapshot {
-						mostRecentSnap = sfsvs[n]
-						break
-					}
-				}
-				if mostRecentSnap == nil {
-					log.Printf("no snapshot on sender side")
-					continue
-				}
-				log.Printf("starting at most recent snapshot %s", mostRecentSnap)
-				path = []*FilesystemVersion{mostRecentSnap}
+		if conflict != nil {
+			log.Printf("conflict: %s", conflict)
+			var msg string
+			path, msg = resolveConflict(conflict)
+			if path != nil {
+				log.Printf("conflict resolved: %s", msg)
+			} else {
+				log.Printf("%s", msg)
 			}
-		} else if conflict != nil {
-			log.Printf("unresolvable conflict: %s", conflict)
-			// handle or ignore for now
-			continue
 		}
+		if path == nil {
+			return FilesystemReplicationResult{Unfixable: true}
+		}

-		ipr.Replicate(ctx, ep.Sender(), ep.Receiver(), NewCopier(), fs, path)
+		return ipr.Replicate(ctx, ep.Sender(), ep.Receiver(), NewCopier(), fs, path)

-	}
+	})

 }
@@ -185,7 +265,7 @@ func NewCopier() Copier {
 }

 type IncrementalPathReplicator interface {
-	Replicate(ctx context.Context, sender Sender, receiver Receiver, copier Copier, fs *Filesystem, path []*FilesystemVersion)
+	Replicate(ctx context.Context, sender Sender, receiver Receiver, copier Copier, fs *Filesystem, path []*FilesystemVersion) FilesystemReplicationResult
 }

 type incrementalPathReplicator struct{}
@@ -194,14 +274,13 @@ func NewIncrementalPathReplicator() IncrementalPathReplicator {
 	return incrementalPathReplicator{}
 }

-func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, receiver Receiver, copier Copier, fs *Filesystem, path []*FilesystemVersion) {
+func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, receiver Receiver, copier Copier, fs *Filesystem, path []*FilesystemVersion) FilesystemReplicationResult {

 	log := ctx.Value(ContextKeyLog).(Logger)

 	if len(path) == 0 {
-		log.Printf("nothing to do")
-		return
+		// nothing to do
+		return FilesystemReplicationResult{Done: true}
 	}

 	if len(path) == 1 {
@@ -215,8 +294,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r
 		sres, sstream, err := sender.Send(ctx, sr)
 		if err != nil {
 			log.Printf("send request failed: %s", err)
-			// FIXME must close connection...
-			return
+			return handleGenericEndpointError(err)
 		}

 		rr := &ReceiveReq{
rr := &ReceiveReq{
|
||||
@ -225,20 +303,19 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r
|
||||
}
|
||||
err = receiver.Receive(ctx, rr, sstream)
|
||||
if err != nil {
|
||||
// FIXME this failure could be due to an unexpected exit of ZFS on the sending side
|
||||
// FIXME which is transported through the streamrpc protocol, and known to the sendStream.(*streamrpc.streamReader),
|
||||
// FIXME but the io.Reader interface design doesn not allow us to infer that it is a *streamrpc.streamReader right now
|
||||
log.Printf("receive request failed (might also be error on sender...): %s", err)
|
||||
// FIXME must close connection
|
||||
return
|
||||
log.Printf("receive request failed (might also be error on sender): %s", err)
|
||||
sstream.Close()
|
||||
// This failure could be due to
|
||||
// - an unexpected exit of ZFS on the sending side
|
||||
// - an unexpected exit of ZFS on the receiving side
|
||||
// - a connectivity issue
|
||||
return handleGenericEndpointError(err)
|
||||
}
|
||||
|
||||
return
|
||||
return FilesystemReplicationResult{Done: true}
|
||||
}
|
||||
|
||||
usedResumeToken := false
|
||||
|
||||
incrementalLoop:
|
||||
for j := 0; j < len(path)-1; j++ {
|
||||
rt := ""
|
||||
if !usedResumeToken { // only send resume token for first increment
|
||||
@@ -254,8 +331,7 @@ incrementalLoop:
 		sres, sstream, err := sender.Send(ctx, sr)
 		if err != nil {
 			log.Printf("send request failed: %s", err)
-			// handle and ignore
-			break incrementalLoop
+			return handleGenericEndpointError(err)
 		}
 		// try to consume stream
||||
@ -266,10 +342,11 @@ incrementalLoop:
|
||||
err = receiver.Receive(ctx, rr, sstream)
|
||||
if err != nil {
|
||||
log.Printf("receive request failed: %s", err)
|
||||
// handle and ignore
|
||||
break incrementalLoop
|
||||
return handleGenericEndpointError(err) // FIXME resume state on receiver -> update ResumeToken
|
||||
}
|
||||
|
||||
// FIXME handle properties from sres
|
||||
}
|
||||
|
||||
return FilesystemReplicationResult{Done: true}
|
||||
}
|
||||
|
util/contextflexibletimeout.go (new file, 83 lines)
@@ -0,0 +1,83 @@
package util

import (
	"context"
	"time"
	"sync"
)

type contextWithOptionalDeadline struct {
	context.Context

	m        sync.Mutex
	deadline time.Time

	done chan struct{}
	err  error
}

func (c *contextWithOptionalDeadline) Deadline() (deadline time.Time, ok bool) {
	c.m.Lock()
	defer c.m.Unlock()
	return c.deadline, !c.deadline.IsZero()
}

func (c *contextWithOptionalDeadline) Err() error {
	c.m.Lock()
	defer c.m.Unlock()
	return c.err
}

func (c *contextWithOptionalDeadline) Done() <-chan struct{} {
	return c.done
}

func ContextWithOptionalDeadline(pctx context.Context) (ctx context.Context, enforceDeadline func(deadline time.Time)) {

	// mctx can only be cancelled by cancelMctx, not by a potential cancel of pctx
	rctx := &contextWithOptionalDeadline{
		Context: pctx,
		done:    make(chan struct{}),
		err:     nil,
	}
	enforceDeadline = func(deadline time.Time) {

		// Set deadline and prohibit multiple calls
		rctx.m.Lock()
		alreadyCalled := !rctx.deadline.IsZero()
		if !alreadyCalled {
			rctx.deadline = deadline
		}
		rctx.m.Unlock()
		if alreadyCalled {
			return
		}

		// Deadline in past?
		sleepTime := deadline.Sub(time.Now())
		if sleepTime <= 0 {
			rctx.m.Lock()
			rctx.err = context.DeadlineExceeded
			rctx.m.Unlock()
			close(rctx.done)
			return
		}
		go func() {
			// Set a timer and wait for timer or parent context to be cancelled
			timer := time.NewTimer(sleepTime)
			var setErr error
			select {
			case <-pctx.Done():
				timer.Stop()
				setErr = pctx.Err()
			case <-timer.C:
				setErr = context.DeadlineExceeded
			}
			rctx.m.Lock()
			rctx.err = setErr
			rctx.m.Unlock()
			close(rctx.done)
		}()
	}
	return rctx, enforceDeadline
}
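Editor's note: unlike context.WithDeadline, the deadline here need not be known when the context is created; enforceDeadline may be called once at any later point, and subsequent calls are ignored. A minimal usage sketch of the API added above:

	package main

	import (
		"context"
		"fmt"
		"time"

		"github.com/zrepl/zrepl/util"
	)

	func main() {
		ctx, enforceDeadline := util.ContextWithOptionalDeadline(context.Background())

		// No deadline yet: ctx behaves exactly like its parent.
		if _, ok := ctx.Deadline(); !ok {
			fmt.Println("no deadline set")
		}

		// Once the deadline is actually known, enforce it (later calls are no-ops).
		enforceDeadline(time.Now().Add(100 * time.Millisecond))

		<-ctx.Done()
		fmt.Println(ctx.Err()) // context.DeadlineExceeded
	}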
util/contextflexibletimeout_test.go (new file, 84 lines)
@@ -0,0 +1,84 @@
package util

import (
	"testing"
	"context"
	"time"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/assert"
)

func TestContextWithOptionalDeadline(t *testing.T) {

	ctx := context.Background()
	cctx, enforceDeadline := ContextWithOptionalDeadline(ctx)

	begin := time.Now()
	var receivedCancellation time.Time
	var cancellationError error
	go func() {
		select {
		case <-cctx.Done():
			receivedCancellation = time.Now()
			cancellationError = cctx.Err()
		case <-time.After(600 * time.Millisecond):
			t.Fatalf("should have been cancelled by deadline")
		}
	}()
	time.Sleep(100 * time.Millisecond)
	if !receivedCancellation.IsZero() {
		t.Fatalf("no enforcement means no cancellation")
	}
	require.Nil(t, cctx.Err(), "no error while not cancelled")
	dl, ok := cctx.Deadline()
	require.False(t, ok)
	require.Zero(t, dl)
	enforceDeadline(begin.Add(200 * time.Millisecond))
	// second call must be ignored, i.e. we expect the deadline to be at begin+200ms, not begin+400ms
	enforceDeadline(begin.Add(400 * time.Millisecond))

	time.Sleep(300 * time.Millisecond) // 100ms margin for scheduler
	if receivedCancellation.Sub(begin) > 250*time.Millisecond {
		t.Fatalf("cancellation is beyond acceptable scheduler latency")
	}
	require.Equal(t, context.DeadlineExceeded, cancellationError)
}

func TestContextWithOptionalDeadlineNegativeDeadline(t *testing.T) {
	ctx := context.Background()
	cctx, enforceDeadline := ContextWithOptionalDeadline(ctx)
	enforceDeadline(time.Now().Add(-10 * time.Second))
	select {
	case <-cctx.Done():
	default:
		t.FailNow()
	}
}

func TestContextWithOptionalDeadlineParentCancellation(t *testing.T) {

	pctx, cancel := context.WithCancel(context.Background())
	cctx, enforceDeadline := ContextWithOptionalDeadline(pctx)

	// 0 ms
	start := time.Now()
	enforceDeadline(start.Add(400 * time.Millisecond))
	time.Sleep(100 * time.Millisecond)
	cancel()                         // cancel @ ~100ms
	time.Sleep(100 * time.Millisecond) // give 100ms time to propagate cancel

	// @ ~200ms
	select {
	case <-cctx.Done():
		assert.True(t, time.Now().Before(start.Add(300*time.Millisecond)))
		assert.Equal(t, context.Canceled, cctx.Err())
	default:
		t.FailNow()
	}

}

func TestContextWithOptionalDeadlineValue(t *testing.T) {
	pctx := context.WithValue(context.Background(), "key", "value")
	cctx, _ := ContextWithOptionalDeadline(pctx)
	assert.Equal(t, "value", cctx.Value("key"))
}