mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-11 00:39:51 +01:00
195 lines
5.0 KiB
Go
195 lines
5.0 KiB
Go
|
package stream
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
|
||
|
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||
|
"github.com/zrepl/zrepl/zfs"
|
||
|
)
|
||
|
|
||
|
type Conn struct {
|
||
|
hc *heartbeatconn.Conn
|
||
|
|
||
|
// whether the per-conn readFrames goroutine completed
|
||
|
waitReadFramesDone chan struct{}
|
||
|
// filled by per-conn readFrames goroutine
|
||
|
frameReads chan readFrameResult
|
||
|
|
||
|
// readMtx serializes read stream operations because we inherently only
|
||
|
// support a single stream at a time over hc.
|
||
|
readMtx sync.Mutex
|
||
|
readClean bool
|
||
|
allowWriteStreamTo bool
|
||
|
|
||
|
// writeMtx serializes write stream operations because we inherently only
|
||
|
// support a single stream at a time over hc.
|
||
|
writeMtx sync.Mutex
|
||
|
writeClean bool
|
||
|
}
|
||
|
|
||
|
// readMessageSentinel is the error ReadStreamedMessage closes its internal
// pipe with, telling the buffering goroutine that the stream is complete.
var readMessageSentinel error = fmt.Errorf("read stream complete")
|
||
|
|
||
|
// writeStreamToErrorUnknownState is returned by ReadStreamInto when an
// earlier read has left the connection in an unknown state. It classifies
// itself as a read error and never as a write error.
type writeStreamToErrorUnknownState struct{}

// Error implements the error interface.
func (writeStreamToErrorUnknownState) Error() string {
	return "dataconn read stream: connection is in unknown state"
}

// IsReadError always reports true.
func (writeStreamToErrorUnknownState) IsReadError() bool {
	return true
}

// IsWriteError always reports false.
func (writeStreamToErrorUnknownState) IsWriteError() bool {
	return false
}
|
||
|
|
||
|
func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn {
|
||
|
hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout)
|
||
|
conn := &Conn{
|
||
|
hc: hc, readClean: true, writeClean: true,
|
||
|
waitReadFramesDone: make(chan struct{}),
|
||
|
frameReads: make(chan readFrameResult, 5), // FIXME constant
|
||
|
}
|
||
|
go conn.readFrames()
|
||
|
return conn
|
||
|
}
|
||
|
|
||
|
func isConnCleanAfterRead(res *ReadStreamError) bool {
|
||
|
return res == nil || res.Kind == ReadStreamErrorKindSource || res.Kind == ReadStreamErrorKindStreamErrTrailerEncoding
|
||
|
}
|
||
|
|
||
|
// isConnCleanAfterWrite reports whether a write operation that finished with
// err left the connection in a known state; any non-nil error taints it.
func isConnCleanAfterWrite(err error) bool {
	switch err {
	case nil:
		return true
	default:
		return false
	}
}
|
||
|
|
||
|
// ErrReadFramesStopped indicates that reading frames from the connection has stopped.
var ErrReadFramesStopped error = fmt.Errorf("stream: reading frames stopped")
|
||
|
|
||
|
func (c *Conn) readFrames() {
|
||
|
defer close(c.waitReadFramesDone)
|
||
|
defer close(c.frameReads)
|
||
|
readFrames(c.frameReads, c.hc)
|
||
|
}
|
||
|
|
||
|
func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameType uint32) ([]byte, *ReadStreamError) {
|
||
|
c.readMtx.Lock()
|
||
|
defer c.readMtx.Unlock()
|
||
|
if !c.readClean {
|
||
|
return nil, &ReadStreamError{
|
||
|
Kind: ReadStreamErrorKindConn,
|
||
|
Err: fmt.Errorf("dataconn read message: connection is in unknown state"),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
r, w := io.Pipe()
|
||
|
var buf bytes.Buffer
|
||
|
var wg sync.WaitGroup
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
lr := io.LimitReader(r, int64(maxSize))
|
||
|
if _, err := io.Copy(&buf, lr); err != nil && err != readMessageSentinel {
|
||
|
panic(err)
|
||
|
}
|
||
|
}()
|
||
|
err := readStream(c.frameReads, c.hc, w, frameType)
|
||
|
c.readClean = isConnCleanAfterRead(err)
|
||
|
w.CloseWithError(readMessageSentinel)
|
||
|
wg.Wait()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
} else {
|
||
|
return buf.Bytes(), nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WriteStreamTo reads a stream from Conn and writes it to w.
|
||
|
func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) zfs.StreamCopierError {
|
||
|
c.readMtx.Lock()
|
||
|
defer c.readMtx.Unlock()
|
||
|
if !c.readClean {
|
||
|
return writeStreamToErrorUnknownState{}
|
||
|
}
|
||
|
var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
|
||
|
c.readClean = isConnCleanAfterRead(err)
|
||
|
|
||
|
// https://golang.org/doc/faq#nil_error
|
||
|
if err == nil {
|
||
|
return nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) error {
|
||
|
c.writeMtx.Lock()
|
||
|
defer c.writeMtx.Unlock()
|
||
|
if !c.writeClean {
|
||
|
return fmt.Errorf("dataconn write message: connection is in unknown state")
|
||
|
}
|
||
|
errBuf, errConn := writeStream(ctx, c.hc, buf, frameType)
|
||
|
if errBuf != nil {
|
||
|
panic(errBuf)
|
||
|
}
|
||
|
c.writeClean = isConnCleanAfterWrite(errConn)
|
||
|
return errConn
|
||
|
}
|
||
|
|
||
|
func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) error {
|
||
|
c.writeMtx.Lock()
|
||
|
defer c.writeMtx.Unlock()
|
||
|
if !c.writeClean {
|
||
|
return fmt.Errorf("dataconn send stream: connection is in unknown state")
|
||
|
}
|
||
|
|
||
|
// avoid io.Pipe if zfs.StreamCopier is an io.Reader
|
||
|
var r io.Reader
|
||
|
var w *io.PipeWriter
|
||
|
streamCopierErrChan := make(chan zfs.StreamCopierError, 1)
|
||
|
if reader, ok := src.(io.Reader); ok {
|
||
|
r = reader
|
||
|
streamCopierErrChan <- nil
|
||
|
close(streamCopierErrChan)
|
||
|
} else {
|
||
|
r, w = io.Pipe()
|
||
|
go func() {
|
||
|
streamCopierErrChan <- src.WriteStreamTo(w)
|
||
|
w.Close()
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
type writeStreamRes struct {
|
||
|
errStream, errConn error
|
||
|
}
|
||
|
writeStreamErrChan := make(chan writeStreamRes, 1)
|
||
|
go func() {
|
||
|
var res writeStreamRes
|
||
|
res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType)
|
||
|
if w != nil {
|
||
|
w.CloseWithError(res.errStream)
|
||
|
}
|
||
|
writeStreamErrChan <- res
|
||
|
}()
|
||
|
|
||
|
writeRes := <-writeStreamErrChan
|
||
|
streamCopierErr := <-streamCopierErrChan
|
||
|
c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct?
|
||
|
if streamCopierErr != nil && streamCopierErr.IsReadError() {
|
||
|
return streamCopierErr // something on our side is bad
|
||
|
} else {
|
||
|
if writeRes.errStream != nil {
|
||
|
return writeRes.errStream
|
||
|
} else if writeRes.errConn != nil {
|
||
|
return writeRes.errConn
|
||
|
}
|
||
|
// TODO combined error?
|
||
|
return streamCopierErr
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (c *Conn) Close() error {
|
||
|
err := c.hc.Shutdown()
|
||
|
<-c.waitReadFramesDone
|
||
|
return err
|
||
|
}
|