diff --git a/rpc/dataconn/frameconn/frameconn.go b/rpc/dataconn/frameconn/frameconn.go
index 4dc83ff..be89318 100644
--- a/rpc/dataconn/frameconn/frameconn.go
+++ b/rpc/dataconn/frameconn/frameconn.go
@@ -321,7 +321,7 @@ func (c *Conn) Shutdown(deadline time.Time) error {
 
 	c.shutdown.Begin()
 	// new calls to c.ReadFrame and c.WriteFrame will now return ErrShutdown
-	// Aquiring writeMtx and readMtx ensures that the last calls exit successfully
+	// Acquiring writeMtx and readMtx afterwards ensures that already-running calls exit successfully
 
 	// disable renewing timeouts now, enforce the requested deadline instead
 	// we need to do this before aquiring locks to enforce the timeout on slow
diff --git a/rpc/dataconn/stream/stream.go b/rpc/dataconn/stream/stream.go
index 554ba22..b46d000 100644
--- a/rpc/dataconn/stream/stream.go
+++ b/rpc/dataconn/stream/stream.go
@@ -213,10 +213,17 @@ type readFrameResult struct {
 	err error
 }
 
-func readFrames(reads chan<- readFrameResult, c *heartbeatconn.Conn) {
-	for {
+// readFrames reads from c into reads
+// if a read from c encounters an error, noMoreReads is closed before sending the result into reads
+func readFrames(reads chan<- readFrameResult, noMoreReads chan<- struct{}, c *heartbeatconn.Conn) {
+	// noMoreReads is already closed, don't re-close it
+	defer close(reads)
+	for { // only exits after a read error, make sure noMoreReads is closed
 		var r readFrameResult
 		r.f, r.err = c.ReadFrame()
+		if r.err != nil && noMoreReads != nil {
+			close(noMoreReads)
+		}
 		reads <- r
 		if r.err != nil {
 			return
diff --git a/rpc/dataconn/stream/stream_conn.go b/rpc/dataconn/stream/stream_conn.go
index 66f36b7..a7c7879 100644
--- a/rpc/dataconn/stream/stream_conn.go
+++ b/rpc/dataconn/stream/stream_conn.go
@@ -5,9 +5,13 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
+	"github.com/pkg/errors"
+
 	"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
 	"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
 	"github.com/zrepl/zrepl/zfs"
@@ -21,6 +25,8 @@ type Conn struct {
 	// filled by per-conn readFrames goroutine
 	frameReads chan readFrameResult
 
+	closeState closeState
+
 	// readMtx serializes read stream operations because we inherently only
 	// support a single stream at a time over hc.
 	readMtx sync.Mutex
@@ -63,15 +69,29 @@ func isConnCleanAfterWrite(err error) bool {
 	return err == nil
 }
 
-var ErrReadFramesStopped = fmt.Errorf("stream: reading frames stopped")
-
 func (c *Conn) readFrames() {
-	defer close(c.waitReadFramesDone)
-	defer close(c.frameReads)
-	readFrames(c.frameReads, c.hc)
+	readFrames(c.frameReads, c.waitReadFramesDone, c.hc)
 }
 
-func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) ([]byte, *ReadStreamError) {
+func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) (_ []byte, err *ReadStreamError) {
+
+	// if we are closed while reading, return that as an error
+	if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
+		return nil, &ReadStreamError{
+			Kind: ReadStreamErrorKindConn,
+			Err:  cse,
+		}
+	} else {
+		defer func(err **ReadStreamError) {
+			if closed := closeGuard.RWExit(); closed != nil {
+				*err = &ReadStreamError{
+					Kind: ReadStreamErrorKindConn,
+					Err:  closed,
+				}
+			}
+		}(&err)
+	}
+
 	c.readMtx.Lock()
 	defer c.readMtx.Unlock()
 	if !c.readClean {
@@ -92,7 +112,7 @@ func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameTyp
 			panic(err)
 		}
 	}()
-	err := readStream(c.frameReads, c.hc, w, frameType)
+	err = readStream(c.frameReads, c.hc, w, frameType)
 	c.readClean = isConnCleanAfterRead(err)
 	_ = w.CloseWithError(readMessageSentinel) // always returns nil
 	wg.Wait()
@@ -104,23 +124,47 @@ func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameTyp
 }
 
 // WriteStreamTo reads a stream from Conn and writes it to w.
-func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) zfs.StreamCopierError {
+func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopierError) {
+
+	// if we are closed while reading, return that as an error
+	if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
+		return cse
+	} else {
+		defer func(err *zfs.StreamCopierError) {
+			if closed := closeGuard.RWExit(); closed != nil {
+				*err = closed
+			}
+		}(&err)
+	}
+
 	c.readMtx.Lock()
 	defer c.readMtx.Unlock()
 	if !c.readClean {
 		return writeStreamToErrorUnknownState{}
 	}
-	var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
-	c.readClean = isConnCleanAfterRead(err)
+	var rse *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
+	c.readClean = isConnCleanAfterRead(rse)
 
 	// https://golang.org/doc/faq#nil_error
-	if err == nil {
+	if rse == nil {
 		return nil
 	}
-	return err
+	return rse
 }
 
-func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) error {
+func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) (err error) {
+
+	// if we are closed while writing, return that as an error
+	if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
+		return cse
+	} else {
+		defer func(err *error) {
+			if closed := closeGuard.RWExit(); closed != nil {
+				*err = closed
+			}
+		}(&err)
+	}
+
 	c.writeMtx.Lock()
 	defer c.writeMtx.Unlock()
 	if !c.writeClean {
@@ -134,7 +178,19 @@ func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameTyp
 	return errConn
 }
 
-func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) error {
+func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) (err error) {
+
+	// if we are closed while writing, return that as an error
+	if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
+		return cse
+	} else {
+		defer func(err *error) {
+			if closed := closeGuard.RWExit(); closed != nil {
+				*err = closed
+			}
+		}(&err)
+	}
+
 	c.writeMtx.Lock()
 	defer c.writeMtx.Unlock()
 	if !c.writeClean {
@@ -186,12 +242,90 @@ func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType u
 	}
 }
 
+// closeState records whether Close has been called on a Conn so that
+// read/write operations can detect a Close that happens before or during them.
+type closeState struct {
+	// closeCount is the number of CloseEntry calls so far; accessed atomically.
+	closeCount uint32
+}
+
+// closeStateErrConnectionClosed is the error returned to read/write operations
+// that were started after, or were still running during, a call to Close.
+type closeStateErrConnectionClosed struct{}
+
+var _ zfs.StreamCopierError = (*closeStateErrConnectionClosed)(nil)
+var _ error = (*closeStateErrConnectionClosed)(nil)
+var _ net.Error = (*closeStateErrConnectionClosed)(nil)
+
+func (e *closeStateErrConnectionClosed) Error() string {
+	return "connection closed"
+}
+func (e *closeStateErrConnectionClosed) IsReadError() bool  { return true }
+func (e *closeStateErrConnectionClosed) IsWriteError() bool { return true }
+func (e *closeStateErrConnectionClosed) Timeout() bool      { return false }
+func (e *closeStateErrConnectionClosed) Temporary() bool    { return false }
+
+// CloseEntry registers a call to Close.
+// It returns an error for every call except the first one.
+func (s *closeState) CloseEntry() error {
+	firstCloser := atomic.AddUint32(&s.closeCount, 1) == 1
+	if !firstCloser {
+		return errors.New("duplicate close")
+	}
+	return nil
+}
+
+// closeStateEntry is the guard handed out by RWEntry; it remembers the
+// closeCount observed when the read/write operation started.
+type closeStateEntry struct {
+	s          *closeState
+	entryCount uint32
+}
+
+// RWEntry must be called at the start of a read/write operation.
+// It fails if Close has already been called, otherwise it returns a guard
+// whose RWExit must be called when the operation ends.
+func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) {
+	entry := &closeStateEntry{s, atomic.LoadUint32(&s.closeCount)}
+	if entry.entryCount > 0 {
+		return nil, &closeStateErrConnectionClosed{}
+	}
+	return entry, nil
+}
+
+// RWExit reports whether Close was called while the operation guarded by e
+// was running, by comparing the current closeCount of the shared closeState
+// with the value captured in RWEntry.
+func (e *closeStateEntry) RWExit() zfs.StreamCopierError {
+	if atomic.LoadUint32(&e.s.closeCount) == e.entryCount {
+		// no calls to Close() while running rw operation
+		return nil
+	}
+	return &closeStateErrConnectionClosed{}
+}
+
 func (c *Conn) Close() error {
+	if err := c.closeState.CloseEntry(); err != nil {
+		return errors.Wrap(err, "stream conn close")
+	}
+
+	// Shutdown c.hc, which will cause c.readFrames to close c.waitReadFramesDone
 	err := c.hc.Shutdown()
-	<-c.waitReadFramesDone
+	// However, c.readFrames may be blocking on a filled c.frameReads
+	// and since the connection is closed, nobody is going to read from it
 	for read := range c.frameReads {
 		debug("Conn.Close() draining queued read")
 		read.f.Buffer.Free()
 	}
-	return err
+	// if we can't close, don't expect c.readFrames to terminate
+	// this might leak c.readFrames, but we can't do something useful at this point
+	if err != nil {
+		return err
+	}
+	// shutdown didn't report errors, so c.readFrames will exit due to read error
+	// This behavior is the contract we have with c.hc.
+	// If that contract is broken, this read will block indefinitely and
+	// cause an easily diagnosable goroutine leak (of this goroutine)
+	<-c.waitReadFramesDone
+	return nil
 }
diff --git a/rpc/dataconn/stream/stream_test.go b/rpc/dataconn/stream/stream_test.go
index d1dabcc..5c4b0ab 100644
--- a/rpc/dataconn/stream/stream_test.go
+++ b/rpc/dataconn/stream/stream_test.go
@@ -57,7 +57,7 @@ func TestStreamer(t *testing.T) {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
-			readFrames(ch, b)
+			readFrames(ch, nil, b)
 		}()
 		err := readStream(ch, b, &buf, stype)
 		log.WithField("errType", fmt.Sprintf("%T %v", err, err)).Debug("ReadStream returned")
@@ -113,7 +113,7 @@ func TestMultiFrameStreamErrTraileror(t *testing.T) {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
-			readFrames(ch, b)
+			readFrames(ch, nil, b)
 		}()
 		err := readStream(ch, b, &buf, stype)
 		t.Logf("%s", err)