diff --git a/replication/logic/replication_logic.go b/replication/logic/replication_logic.go index ff0e3d6..13458b7 100644 --- a/replication/logic/replication_logic.go +++ b/replication/logic/replication_logic.go @@ -14,6 +14,7 @@ import ( "github.com/zrepl/zrepl/replication/logic/pdu" "github.com/zrepl/zrepl/replication/report" "github.com/zrepl/zrepl/util/bytecounter" + "github.com/zrepl/zrepl/util/chainlock" "github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/semaphore" "github.com/zrepl/zrepl/zfs" @@ -147,8 +148,12 @@ type Step struct { parent *Filesystem from, to *pdu.FilesystemVersion // compat - byteCounter bytecounter.StreamCopier expectedSize int64 // 0 means no size estimate present / possible + + // byteCounter is nil initially, and set later in Step.doReplication + // => concurrent read of that pointer from Step.ReportInfo must be protected + byteCounter bytecounter.StreamCopier + byteCounterMtx chainlock.L } func (s *Step) TargetEquals(other driver.Step) bool { @@ -172,10 +177,15 @@ func (s *Step) Step(ctx context.Context) error { } func (s *Step) ReportInfo() *report.StepInfo { + + // get current byteCounter value var byteCounter int64 + s.byteCounterMtx.Lock() if s.byteCounter != nil { byteCounter = s.byteCounter.Count() } + s.byteCounterMtx.Unlock() + // FIXME stick to zfs convention of from and to from := "" if s.from != nil { @@ -457,8 +467,12 @@ func (s *Step) doReplication(ctx context.Context) error { defer sstreamCopier.Close() // Install a byte counter to track progress + for status report - s.byteCounter = bytecounter.NewStreamCopier(sstreamCopier) + byteCountingStream := bytecounter.NewStreamCopier(sstreamCopier) + s.byteCounterMtx.Lock() + s.byteCounter = byteCountingStream + s.byteCounterMtx.Unlock() defer func() { + defer s.byteCounterMtx.Lock().Unlock() s.parent.promBytesReplicated.Add(float64(s.byteCounter.Count())) }() @@ -467,7 +481,7 @@ func (s *Step) doReplication(ctx context.Context) error { ClearResumeToken: !sres.UsedResumeToken, } log.Debug("initiate receive request") - _, err = s.receiver.Receive(ctx, rr, s.byteCounter) + _, err = s.receiver.Receive(ctx, rr, byteCountingStream) if err != nil { log. WithError(err).