diff --git a/rpc/dataconn/stream/stream.go b/rpc/dataconn/stream/stream.go index 4d3e900..554ba22 100644 --- a/rpc/dataconn/stream/stream.go +++ b/rpc/dataconn/stream/stream.go @@ -7,6 +7,7 @@ import ( "io" "net" "strings" + "sync/atomic" "unicode/utf8" "github.com/zrepl/zrepl/logger" @@ -81,8 +82,10 @@ func doWriteStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, } reads := make(chan read, 5) + var stopReading uint32 go func() { - for { + defer close(reads) + for atomic.LoadUint32(&stopReading) == 0 { buffer := bufpool.Get(1 << FramePayloadShift) bufferBytes := buffer.Bytes() n, err := io.ReadFull(stream, bufferBytes) @@ -97,12 +100,21 @@ func doWriteStream(ctx context.Context, c *heartbeatconn.Conn, stream io.Reader, } if err != nil { reads <- read{err: err} // RULE1 - close(reads) return } } }() + defer func() { + // stop reading + atomic.StoreUint32(&stopReading, 1) + // drain in-flight reads + for read := range reads { + debug("doWriteStream: drain read channel") + read.buf.Free() + } + }() + for read := range reads { if read.err == nil { // RULE 1: read.buf is valid diff --git a/rpc/dataconn/stream/stream_conn.go b/rpc/dataconn/stream/stream_conn.go index 7b1b081..66f36b7 100644 --- a/rpc/dataconn/stream/stream_conn.go +++ b/rpc/dataconn/stream/stream_conn.go @@ -189,5 +189,9 @@ func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType u func (c *Conn) Close() error { err := c.hc.Shutdown() <-c.waitReadFramesDone + for read := range c.frameReads { + debug("Conn.Close() draining queued read") + read.f.Buffer.Free() + } return err }