replication/logic: fix race when reading byte counter pointer for report

fixes #214
This commit is contained in:
Christian Schwarz 2019-09-28 15:00:14 +02:00
parent f976212ec9
commit f9c7766073

View File

@ -14,6 +14,7 @@ import (
"github.com/zrepl/zrepl/replication/logic/pdu"
"github.com/zrepl/zrepl/replication/report"
"github.com/zrepl/zrepl/util/bytecounter"
"github.com/zrepl/zrepl/util/chainlock"
"github.com/zrepl/zrepl/util/envconst"
"github.com/zrepl/zrepl/util/semaphore"
"github.com/zrepl/zrepl/zfs"
@ -147,8 +148,12 @@ type Step struct {
parent *Filesystem
from, to *pdu.FilesystemVersion // compat
byteCounter bytecounter.StreamCopier
expectedSize int64 // 0 means no size estimate present / possible
// byteCounter is nil initially, and set later in Step.doReplication
// => concurrent read of that pointer from Step.ReportInfo must be protected
byteCounter bytecounter.StreamCopier
byteCounterMtx chainlock.L
}
func (s *Step) TargetEquals(other driver.Step) bool {
@ -172,10 +177,15 @@ func (s *Step) Step(ctx context.Context) error {
}
func (s *Step) ReportInfo() *report.StepInfo {
// get current byteCounter value
var byteCounter int64
s.byteCounterMtx.Lock()
if s.byteCounter != nil {
byteCounter = s.byteCounter.Count()
}
s.byteCounterMtx.Unlock()
// FIXME stick to zfs convention of from and to
from := ""
if s.from != nil {
@ -457,8 +467,12 @@ func (s *Step) doReplication(ctx context.Context) error {
defer sstreamCopier.Close()
// Install a byte counter to track progress + for status report
s.byteCounter = bytecounter.NewStreamCopier(sstreamCopier)
byteCountingStream := bytecounter.NewStreamCopier(sstreamCopier)
s.byteCounterMtx.Lock()
s.byteCounter = byteCountingStream
s.byteCounterMtx.Unlock()
defer func() {
defer s.byteCounterMtx.Lock().Unlock()
s.parent.promBytesReplicated.Add(float64(s.byteCounter.Count()))
}()
@ -467,7 +481,7 @@ func (s *Step) doReplication(ctx context.Context) error {
ClearResumeToken: !sres.UsedResumeToken,
}
log.Debug("initiate receive request")
_, err = s.receiver.Receive(ctx, rr, s.byteCounter)
_, err = s.receiver.Receive(ctx, rr, byteCountingStream)
if err != nil {
log.
WithError(err).