mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-25 01:44:43 +01:00
rpc + zfs: drop zfs.StreamCopier, use io.ReadCloser instead
This commit is contained in:
parent
c1c9d99a6f
commit
5aaac49382
@ -4,6 +4,7 @@ package endpoint
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"path"
|
||||
"sync"
|
||||
|
||||
@ -241,7 +242,7 @@ func sendArgsFromPDUAndValidateExistsAndGetVersion(ctx context.Context, fs strin
|
||||
return version, nil
|
||||
}
|
||||
|
||||
func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
|
||||
_, err := s.filterCheckFS(r.Filesystem)
|
||||
if err != nil {
|
||||
@ -339,11 +340,11 @@ func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.St
|
||||
|
||||
// step holds & replication cursor released / moved forward in s.SendCompleted => s.moveCursorAndReleaseSendHolds
|
||||
|
||||
streamCopier, err := zfs.ZFSSend(ctx, sendArgs)
|
||||
sendStream, err := zfs.ZFSSend(ctx, sendArgs)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "zfs send failed")
|
||||
}
|
||||
return res, streamCopier, nil
|
||||
return res, sendStream, nil
|
||||
}
|
||||
|
||||
func (p *Sender) SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error) {
|
||||
@ -476,7 +477,7 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs
|
||||
return &pdu.ReplicationCursorRes{Result: &pdu.ReplicationCursorRes_Guid{Guid: cursor.Guid}}, nil
|
||||
}
|
||||
|
||||
func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, _ io.ReadCloser) (*pdu.ReceiveRes, error) {
|
||||
return nil, fmt.Errorf("sender does not implement Receive()")
|
||||
}
|
||||
|
||||
@ -680,13 +681,13 @@ func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq)
|
||||
return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver")
|
||||
}
|
||||
|
||||
func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
return nil, nil, fmt.Errorf("receiver does not implement Send()")
|
||||
}
|
||||
|
||||
var maxConcurrentZFSRecvSemaphore = semaphore.New(envconst.Int64("ZREPL_ENDPOINT_MAX_CONCURRENT_RECV", 10))
|
||||
|
||||
func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error) {
|
||||
getLogger(ctx).Debug("incoming Receive")
|
||||
defer receive.Close()
|
||||
|
||||
|
@ -136,7 +136,7 @@ func makeResumeSituation(ctx *platformtest.Context, src dummySnapshotSituation,
|
||||
return situation
|
||||
}
|
||||
|
||||
limitedCopier := zfs.NewReadCloserCopier(limitio.ReadCloser(copier, src.dummyDataLen/2))
|
||||
limitedCopier := limitio.ReadCloser(copier, src.dummyDataLen/2)
|
||||
defer limitedCopier.Close()
|
||||
|
||||
require.NotNil(ctx, sendArgs.To)
|
||||
|
@ -35,7 +35,7 @@ func SendArgsValidationEncryptedSendOfUnencryptedDatasetForbidden(ctx *platformt
|
||||
ResumeToken: "",
|
||||
}.Validate(ctx)
|
||||
|
||||
var stream *zfs.ReadCloserCopier
|
||||
var stream *zfs.SendStream
|
||||
if err == nil {
|
||||
stream, err = zfs.ZFSSend(ctx, sendArgs) // no shadow
|
||||
if err == nil {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -38,7 +39,7 @@ type Sender interface {
|
||||
// If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before
|
||||
// any next call to the parent github.com/zrepl/zrepl/replication.Endpoint.
|
||||
// If the send request is for dry run the io.ReadCloser will be nil
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
|
||||
SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error)
|
||||
ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
|
||||
}
|
||||
@ -47,7 +48,7 @@ type Receiver interface {
|
||||
Endpoint
|
||||
// Receive sends r and sendStream (the latter containing a ZFS send stream)
|
||||
// to the parent github.com/zrepl/zrepl/replication.Endpoint.
|
||||
Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
|
||||
Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
|
||||
}
|
||||
|
||||
type PlannerPolicy struct {
|
||||
@ -162,7 +163,7 @@ type Step struct {
|
||||
|
||||
// byteCounter is nil initially, and set later in Step.doReplication
|
||||
// => concurrent read of that pointer from Step.ReportInfo must be protected
|
||||
byteCounter bytecounter.StreamCopier
|
||||
byteCounter bytecounter.ReadCloser
|
||||
byteCounterMtx chainlock.L
|
||||
}
|
||||
|
||||
@ -606,19 +607,19 @@ func (s *Step) doReplication(ctx context.Context) error {
|
||||
sr := s.buildSendRequest(false)
|
||||
|
||||
log.Debug("initiate send request")
|
||||
sres, sstreamCopier, err := s.sender.Send(ctx, sr)
|
||||
sres, stream, err := s.sender.Send(ctx, sr)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("send request failed")
|
||||
return err
|
||||
}
|
||||
if sstreamCopier == nil {
|
||||
if stream == nil {
|
||||
err := errors.New("send request did not return a stream, broken endpoint implementation")
|
||||
return err
|
||||
}
|
||||
defer sstreamCopier.Close()
|
||||
defer stream.Close()
|
||||
|
||||
// Install a byte counter to track progress + for status report
|
||||
byteCountingStream := bytecounter.NewStreamCopier(sstreamCopier)
|
||||
byteCountingStream := bytecounter.NewReadCloser(stream)
|
||||
s.byteCounterMtx.Lock()
|
||||
s.byteCounter = byteCountingStream
|
||||
s.byteCounterMtx.Unlock()
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
@ -11,7 +12,6 @@ import (
|
||||
"github.com/zrepl/zrepl/replication/logic/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/stream"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@ -26,7 +26,7 @@ func NewClient(connecter transport.Connecter, log Logger) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error {
|
||||
func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, stream io.ReadCloser) error {
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, memErr := buf.WriteString(endpoint)
|
||||
@ -46,8 +46,8 @@ func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, r
|
||||
return err
|
||||
}
|
||||
|
||||
if streamCopier != nil {
|
||||
return conn.SendStream(ctx, streamCopier, ZFSStream)
|
||||
if stream != nil {
|
||||
return conn.SendStream(ctx, stream, ZFSStream)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
@ -109,7 +109,7 @@ func (c *Client) putWire(conn *stream.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
conn, err := c.getWire(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@ -130,16 +130,19 @@ func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, z
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var copier zfs.StreamCopier = nil
|
||||
var stream io.ReadCloser
|
||||
if !req.DryRun {
|
||||
putWireOnReturn = false
|
||||
copier = &streamCopier{streamConn: conn, closeStreamOnClose: true}
|
||||
stream, err = conn.ReadStream(ZFSStream, true) // no shadow
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &res, copier, nil
|
||||
return &res, stream, nil
|
||||
}
|
||||
|
||||
func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
|
||||
|
||||
defer c.log.Debug("ReqRecv returns")
|
||||
conn, err := c.getWire(ctx)
|
||||
@ -166,7 +169,7 @@ func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier
|
||||
|
||||
sendErrChan := make(chan error)
|
||||
go func() {
|
||||
if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil {
|
||||
if err := c.send(ctx, conn, EndpointRecv, req, stream); err != nil {
|
||||
sendErrChan <- err
|
||||
} else {
|
||||
sendErrChan <- nil
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
@ -11,7 +12,6 @@ import (
|
||||
"github.com/zrepl/zrepl/replication/logic/pdu"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/stream"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
// WireInterceptor has a chance to exchange the context and connection on each client connection.
|
||||
@ -21,11 +21,11 @@ type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (con
|
||||
type Handler interface {
|
||||
// Send handles a SendRequest.
|
||||
// The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run.
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
|
||||
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
|
||||
// Receive handles a ReceiveRequest.
|
||||
// It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout
|
||||
// configured in ServerConfig.Shared.IdleConnTimeout.
|
||||
Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
|
||||
Receive(ctx context.Context, r *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
|
||||
// PingDataconn handles a PingReq
|
||||
PingDataconn(ctx context.Context, r *pdu.PingReq) (*pdu.PingRes, error)
|
||||
}
|
||||
@ -111,7 +111,7 @@ func (s *Server) serveConn(nc *transport.AuthConn) {
|
||||
s.log.WithField("endpoint", endpoint).Debug("calling handler")
|
||||
|
||||
var res proto.Message
|
||||
var sendStream zfs.StreamCopier
|
||||
var sendStream io.ReadCloser
|
||||
var handlerErr error
|
||||
switch endpoint {
|
||||
case EndpointSend:
|
||||
@ -127,7 +127,12 @@ func (s *Server) serveConn(nc *transport.AuthConn) {
|
||||
s.log.WithError(err).Error("cannot unmarshal receive request")
|
||||
return
|
||||
}
|
||||
res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING
|
||||
stream, err := c.ReadStream(ZFSStream, false)
|
||||
if err != nil {
|
||||
s.log.WithError(err).Error("cannot open stream in receive request")
|
||||
return
|
||||
}
|
||||
res, handlerErr = s.h.Receive(ctx, &req, stream) // SHADOWING
|
||||
case EndpointPing:
|
||||
var req pdu.PingReq
|
||||
if err := proto.Unmarshal(reqStructured, &req); err != nil {
|
||||
|
@ -1,12 +1,7 @@
|
||||
package dataconn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/stream"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -39,33 +34,3 @@ const (
|
||||
responseHeaderHandlerOk = "HANDLER OK\n"
|
||||
responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n"
|
||||
)
|
||||
|
||||
type streamCopier struct {
|
||||
mtx sync.Mutex
|
||||
used bool
|
||||
streamConn *stream.Conn
|
||||
closeStreamOnClose bool
|
||||
}
|
||||
|
||||
// WriteStreamTo implements zfs.StreamCopier
|
||||
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
if s.used {
|
||||
panic("streamCopier used multiple times")
|
||||
}
|
||||
s.used = true
|
||||
return s.streamConn.ReadStreamInto(w, ZFSStream)
|
||||
}
|
||||
|
||||
// Close implements zfs.StreamCopier
|
||||
func (s *streamCopier) Close() error {
|
||||
// only record the close here, what we do actually depends on whether
|
||||
// the streamCopier is instantiated server-side or client-side
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
if s.closeStreamOnClose {
|
||||
return s.streamConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -29,7 +29,6 @@ import (
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/util/devnoop"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
func orDie(err error) {
|
||||
@ -42,23 +41,9 @@ type readerStreamCopier struct{ io.Reader }
|
||||
|
||||
func (readerStreamCopier) Close() error { return nil }
|
||||
|
||||
type readerStreamCopierErr struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (readerStreamCopierErr) IsReadError() bool { return false }
|
||||
func (readerStreamCopierErr) IsWriteError() bool { return true }
|
||||
|
||||
func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
|
||||
var buf [1 << 21]byte
|
||||
_, err := io.CopyBuffer(w, c.Reader, buf[:])
|
||||
// always assume write error
|
||||
return readerStreamCopierErr{err}
|
||||
}
|
||||
|
||||
type devNullHandler struct{}
|
||||
|
||||
func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
var res pdu.SendRes
|
||||
if args.devnoopReader {
|
||||
return &res, readerStreamCopier{devnoop.Get()}, nil
|
||||
@ -67,12 +52,12 @@ func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, z
|
||||
}
|
||||
}
|
||||
|
||||
func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
|
||||
var out io.Writer = os.Stdout
|
||||
if args.devnoopWriter {
|
||||
out = devnoop.Get()
|
||||
}
|
||||
err := stream.WriteStreamTo(out)
|
||||
_, err := io.Copy(out, stream)
|
||||
var res pdu.ReceiveRes
|
||||
return &res, err
|
||||
}
|
||||
@ -172,7 +157,7 @@ func client() {
|
||||
req := pdu.SendReq{}
|
||||
_, stream, err := client.ReqSend(ctx, &req)
|
||||
orDie(err)
|
||||
err = stream.WriteStreamTo(os.Stdout)
|
||||
_, err = io.Copy(os.Stdout, stream)
|
||||
orDie(err)
|
||||
case "recv":
|
||||
var r io.Reader = os.Stdin
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type Logger = logger.Logger
|
||||
@ -198,7 +197,7 @@ func (e ReadStreamError) Temporary() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var _ zfs.StreamCopierError = &ReadStreamError{}
|
||||
var _ net.Error = &ReadStreamError{}
|
||||
|
||||
func (e ReadStreamError) IsReadError() bool {
|
||||
return e.Kind != ReadStreamErrorKindWrite
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
|
||||
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
@ -40,15 +39,7 @@ type Conn struct {
|
||||
|
||||
var readMessageSentinel = fmt.Errorf("read stream complete")
|
||||
|
||||
type writeStreamToErrorUnknownState struct{}
|
||||
|
||||
func (e writeStreamToErrorUnknownState) Error() string {
|
||||
return "dataconn read stream: connection is in unknown state"
|
||||
}
|
||||
|
||||
func (e writeStreamToErrorUnknownState) IsReadError() bool { return true }
|
||||
|
||||
func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false }
|
||||
var errWriteStreamToErrorUnknownState = fmt.Errorf("dataconn read stream: connection is in unknown state")
|
||||
|
||||
func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn {
|
||||
hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout)
|
||||
@ -123,14 +114,28 @@ func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameTyp
|
||||
}
|
||||
}
|
||||
|
||||
type StreamReader struct {
|
||||
*io.PipeReader
|
||||
conn *Conn
|
||||
closeConnOnClose bool
|
||||
}
|
||||
|
||||
func (r *StreamReader) Close() error {
|
||||
err := r.PipeReader.Close()
|
||||
if r.closeConnOnClose {
|
||||
r.conn.Close() // TODO error logging
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteStreamTo reads a stream from Conn and writes it to w.
|
||||
func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopierError) {
|
||||
func (c *Conn) ReadStream(frameType uint32, closeConnOnClose bool) (_ *StreamReader, err error) {
|
||||
|
||||
// if we are closed while writing, return that as an error
|
||||
if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
|
||||
return cse
|
||||
return nil, cse
|
||||
} else {
|
||||
defer func(err *zfs.StreamCopierError) {
|
||||
defer func(err *error) {
|
||||
if closed := closeGuard.RWExit(); closed != nil {
|
||||
*err = closed
|
||||
}
|
||||
@ -138,18 +143,23 @@ func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopi
|
||||
}
|
||||
|
||||
c.readMtx.Lock()
|
||||
defer c.readMtx.Unlock()
|
||||
if !c.readClean {
|
||||
return writeStreamToErrorUnknownState{}
|
||||
return nil, errWriteStreamToErrorUnknownState
|
||||
}
|
||||
var rse *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
|
||||
c.readClean = isConnCleanAfterRead(rse)
|
||||
|
||||
// https://golang.org/doc/faq#nil_error
|
||||
if rse == nil {
|
||||
return nil
|
||||
}
|
||||
return rse
|
||||
r, w := io.Pipe()
|
||||
go func() {
|
||||
defer c.readMtx.Unlock()
|
||||
var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
|
||||
if err != nil {
|
||||
_ = w.CloseWithError(err) // doc guarantees that error will always be nil
|
||||
} else {
|
||||
w.Close()
|
||||
}
|
||||
c.readClean = isConnCleanAfterRead(err)
|
||||
}()
|
||||
|
||||
return &StreamReader{PipeReader: r, conn: c, closeConnOnClose: closeConnOnClose}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) (err error) {
|
||||
@ -178,7 +188,7 @@ func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameTyp
|
||||
return errConn
|
||||
}
|
||||
|
||||
func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) (err error) {
|
||||
func (c *Conn) SendStream(ctx context.Context, stream io.ReadCloser, frameType uint32) (err error) {
|
||||
|
||||
// if we are closed while reading, return that as an error
|
||||
if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
|
||||
@ -197,49 +207,17 @@ func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType u
|
||||
return fmt.Errorf("dataconn send stream: connection is in unknown state")
|
||||
}
|
||||
|
||||
// avoid io.Pipe if zfs.StreamCopier is an io.Reader
|
||||
var r io.Reader
|
||||
var w *io.PipeWriter
|
||||
streamCopierErrChan := make(chan zfs.StreamCopierError, 1)
|
||||
if reader, ok := src.(io.Reader); ok {
|
||||
r = reader
|
||||
streamCopierErrChan <- nil
|
||||
close(streamCopierErrChan)
|
||||
} else {
|
||||
r, w = io.Pipe()
|
||||
go func() {
|
||||
streamCopierErrChan <- src.WriteStreamTo(w)
|
||||
w.Close()
|
||||
}()
|
||||
}
|
||||
errStream, errConn := writeStream(ctx, c.hc, stream, frameType)
|
||||
|
||||
type writeStreamRes struct {
|
||||
errStream, errConn error
|
||||
}
|
||||
writeStreamErrChan := make(chan writeStreamRes, 1)
|
||||
go func() {
|
||||
var res writeStreamRes
|
||||
res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType)
|
||||
if w != nil {
|
||||
_ = w.CloseWithError(res.errStream) // always returns nil
|
||||
}
|
||||
writeStreamErrChan <- res
|
||||
}()
|
||||
c.writeClean = isConnCleanAfterWrite(errConn) // TODO correct?
|
||||
|
||||
writeRes := <-writeStreamErrChan
|
||||
streamCopierErr := <-streamCopierErrChan
|
||||
c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct?
|
||||
if streamCopierErr != nil && streamCopierErr.IsReadError() {
|
||||
return streamCopierErr // something on our side is bad
|
||||
} else {
|
||||
if writeRes.errStream != nil {
|
||||
return writeRes.errStream
|
||||
} else if writeRes.errConn != nil {
|
||||
return writeRes.errConn
|
||||
}
|
||||
// TODO combined error?
|
||||
return streamCopierErr
|
||||
if errStream != nil {
|
||||
return errStream
|
||||
} else if errConn != nil {
|
||||
return errConn
|
||||
}
|
||||
// TODO combined error?
|
||||
return nil
|
||||
}
|
||||
|
||||
type closeState struct {
|
||||
@ -248,17 +226,13 @@ type closeState struct {
|
||||
|
||||
type closeStateErrConnectionClosed struct{}
|
||||
|
||||
var _ zfs.StreamCopierError = (*closeStateErrConnectionClosed)(nil)
|
||||
var _ error = (*closeStateErrConnectionClosed)(nil)
|
||||
var _ net.Error = (*closeStateErrConnectionClosed)(nil)
|
||||
|
||||
func (e *closeStateErrConnectionClosed) Error() string {
|
||||
return "connection closed"
|
||||
}
|
||||
func (e *closeStateErrConnectionClosed) IsReadError() bool { return true }
|
||||
func (e *closeStateErrConnectionClosed) IsWriteError() bool { return true }
|
||||
func (e *closeStateErrConnectionClosed) Timeout() bool { return false }
|
||||
func (e *closeStateErrConnectionClosed) Temporary() bool { return false }
|
||||
func (e *closeStateErrConnectionClosed) Timeout() bool { return false }
|
||||
func (e *closeStateErrConnectionClosed) Temporary() bool { return false }
|
||||
|
||||
func (s *closeState) CloseEntry() error {
|
||||
firstCloser := atomic.AddUint32(&s.closeCount, 1) == 1
|
||||
@ -273,7 +247,7 @@ type closeStateEntry struct {
|
||||
entryCount uint32
|
||||
}
|
||||
|
||||
func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) {
|
||||
func (s *closeState) RWEntry() (e *closeStateEntry, err net.Error) {
|
||||
entry := &closeStateEntry{s, atomic.LoadUint32(&s.closeCount)}
|
||||
if entry.entryCount > 0 {
|
||||
return nil, &closeStateErrConnectionClosed{}
|
||||
@ -281,7 +255,7 @@ func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) {
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (e *closeStateEntry) RWExit() zfs.StreamCopierError {
|
||||
func (e *closeStateEntry) RWExit() net.Error {
|
||||
if atomic.LoadUint32(&e.entryCount) == e.entryCount {
|
||||
// no calls to Close() while running rw operation
|
||||
return nil
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -20,7 +21,6 @@ import (
|
||||
"github.com/zrepl/zrepl/rpc/versionhandshake"
|
||||
"github.com/zrepl/zrepl/transport"
|
||||
"github.com/zrepl/zrepl/util/envconst"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
// Client implements the active side of a replication setup.
|
||||
@ -82,22 +82,22 @@ func (c *Client) Close() {
|
||||
|
||||
// callers must ensure that the returned io.ReadCloser is closed
|
||||
// TODO expose dataClient interface to the outside world
|
||||
func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
|
||||
func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
|
||||
// TODO the returned sendStream may return a read error created by the remote side
|
||||
res, streamCopier, err := c.dataClient.ReqSend(ctx, r)
|
||||
res, stream, err := c.dataClient.ReqSend(ctx, r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if streamCopier == nil {
|
||||
if stream == nil {
|
||||
return res, nil, nil
|
||||
}
|
||||
|
||||
return res, streamCopier, nil
|
||||
return res, stream, nil
|
||||
|
||||
}
|
||||
|
||||
func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
|
||||
return c.dataClient.ReqRecv(ctx, req, streamCopier)
|
||||
func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
|
||||
return c.dataClient.ReqRecv(ctx, req, stream)
|
||||
}
|
||||
|
||||
func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
|
||||
|
39
util/bytecounter/bytecounter_readcloser.go
Normal file
39
util/bytecounter/bytecounter_readcloser.go
Normal file
@ -0,0 +1,39 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ReadCloser wraps an io.ReadCloser, reimplementing
|
||||
// its interface and counting the bytes written to during copying.
|
||||
type ReadCloser interface {
|
||||
io.ReadCloser
|
||||
Count() int64
|
||||
}
|
||||
|
||||
// NewReadCloser wraps rc.
|
||||
func NewReadCloser(rc io.ReadCloser) ReadCloser {
|
||||
return &readCloser{rc, 0}
|
||||
}
|
||||
|
||||
type readCloser struct {
|
||||
rc io.ReadCloser
|
||||
count int64
|
||||
}
|
||||
|
||||
func (r *readCloser) Count() int64 {
|
||||
return atomic.LoadInt64(&r.count)
|
||||
}
|
||||
|
||||
var _ io.ReadCloser = &readCloser{}
|
||||
|
||||
func (r *readCloser) Close() error {
|
||||
return r.rc.Close()
|
||||
}
|
||||
|
||||
func (r *readCloser) Read(p []byte) (int, error) {
|
||||
n, err := r.rc.Read(p)
|
||||
atomic.AddInt64(&r.count, int64(n))
|
||||
return n, err
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ByteCounterReader struct {
|
||||
reader io.ReadCloser
|
||||
|
||||
// called & accessed synchronously during Read, no external access
|
||||
cb func(full int64)
|
||||
cbEvery time.Duration
|
||||
lastCbAt time.Time
|
||||
|
||||
// set atomically because it may be read by multiple threads
|
||||
bytes int64
|
||||
}
|
||||
|
||||
func NewByteCounterReader(reader io.ReadCloser) *ByteCounterReader {
|
||||
return &ByteCounterReader{
|
||||
reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *ByteCounterReader) SetCallback(every time.Duration, cb func(full int64)) {
|
||||
b.cbEvery = every
|
||||
b.cb = cb
|
||||
}
|
||||
|
||||
func (b *ByteCounterReader) Close() error {
|
||||
return b.reader.Close()
|
||||
}
|
||||
|
||||
func (b *ByteCounterReader) Read(p []byte) (n int, err error) {
|
||||
n, err = b.reader.Read(p)
|
||||
full := atomic.AddInt64(&b.bytes, int64(n))
|
||||
now := time.Now()
|
||||
if b.cb != nil && now.Sub(b.lastCbAt) > b.cbEvery {
|
||||
b.cb(full)
|
||||
b.lastCbAt = now
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *ByteCounterReader) Bytes() int64 {
|
||||
return atomic.LoadInt64(&b.bytes)
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
// StreamCopier wraps a zfs.StreamCopier, reimplementing
|
||||
// its interface and counting the bytes written to during copying.
|
||||
type StreamCopier interface {
|
||||
zfs.StreamCopier
|
||||
Count() int64
|
||||
}
|
||||
|
||||
// NewStreamCopier wraps sc into a StreamCopier.
|
||||
// If sc is io.Reader, it is guaranteed that the returned StreamCopier
|
||||
// implements that interface, too.
|
||||
func NewStreamCopier(sc zfs.StreamCopier) StreamCopier {
|
||||
bsc := &streamCopier{sc, 0}
|
||||
if scr, ok := sc.(io.Reader); ok {
|
||||
return streamCopierAndReader{bsc, scr}
|
||||
} else {
|
||||
return bsc
|
||||
}
|
||||
}
|
||||
|
||||
type streamCopier struct {
|
||||
sc zfs.StreamCopier
|
||||
count int64
|
||||
}
|
||||
|
||||
// proxy writer used by streamCopier
|
||||
type streamCopierWriter struct {
|
||||
parent *streamCopier
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
func (w streamCopierWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = w.w.Write(p)
|
||||
atomic.AddInt64(&w.parent.count, int64(n))
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamCopier) Count() int64 {
|
||||
return atomic.LoadInt64(&s.count)
|
||||
}
|
||||
|
||||
var _ zfs.StreamCopier = &streamCopier{}
|
||||
|
||||
func (s streamCopier) Close() error {
|
||||
return s.sc.Close()
|
||||
}
|
||||
|
||||
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
|
||||
ww := streamCopierWriter{s, w}
|
||||
return s.sc.WriteStreamTo(ww)
|
||||
}
|
||||
|
||||
// a streamCopier whose underlying sc is an io.Reader
|
||||
type streamCopierAndReader struct {
|
||||
*streamCopier
|
||||
asReader io.Reader
|
||||
}
|
||||
|
||||
func (scr streamCopierAndReader) Read(p []byte) (int, error) {
|
||||
n, err := scr.asReader.Read(p)
|
||||
atomic.AddInt64(&scr.streamCopier.count, int64(n))
|
||||
return n, err
|
||||
}
|
@ -1,39 +0,0 @@
|
||||
package bytecounter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type mockStreamCopierAndReader struct {
|
||||
zfs.StreamCopier // to satisfy interface
|
||||
reads int
|
||||
}
|
||||
|
||||
func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) {
|
||||
r.reads++
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
var _ io.Reader = &mockStreamCopierAndReader{}
|
||||
|
||||
func TestNewStreamCopierReexportsReader(t *testing.T) {
|
||||
mock := &mockStreamCopierAndReader{}
|
||||
x := NewStreamCopier(mock)
|
||||
|
||||
r, ok := x.(io.Reader)
|
||||
if !ok {
|
||||
t.Fatalf("%T does not implement io.Reader, hence reader cannout have been wrapped", x)
|
||||
}
|
||||
|
||||
var buf [23]byte
|
||||
n, err := r.Read(buf[:])
|
||||
assert.True(t, mock.reads == 1)
|
||||
assert.True(t, n == len(buf))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, x.Count() == 23)
|
||||
}
|
152
zfs/zfs.go
152
zfs/zfs.go
@ -354,63 +354,6 @@ func (a ZFSSendArgsUnvalidated) buildCommonSendArgs() ([]string, error) {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
type ReadCloserCopier struct {
|
||||
recorder readErrRecorder
|
||||
}
|
||||
|
||||
type readErrRecorder struct {
|
||||
io.ReadCloser
|
||||
readErr error
|
||||
}
|
||||
|
||||
type sendStreamCopierError struct {
|
||||
isReadErr bool // if false, it's a write error
|
||||
err error
|
||||
}
|
||||
|
||||
func (e sendStreamCopierError) Error() string {
|
||||
if e.isReadErr {
|
||||
return fmt.Sprintf("stream: read error: %s", e.err)
|
||||
} else {
|
||||
return fmt.Sprintf("stream: writer error: %s", e.err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e sendStreamCopierError) IsReadError() bool { return e.isReadErr }
|
||||
func (e sendStreamCopierError) IsWriteError() bool { return !e.isReadErr }
|
||||
|
||||
func (r *readErrRecorder) Read(p []byte) (n int, err error) {
|
||||
n, err = r.ReadCloser.Read(p)
|
||||
r.readErr = err
|
||||
return n, err
|
||||
}
|
||||
|
||||
func NewReadCloserCopier(stream io.ReadCloser) *ReadCloserCopier {
|
||||
return &ReadCloserCopier{recorder: readErrRecorder{stream, nil}}
|
||||
}
|
||||
|
||||
func (c *ReadCloserCopier) WriteStreamTo(w io.Writer) StreamCopierError {
|
||||
debug("sendStreamCopier.WriteStreamTo: begin")
|
||||
_, err := io.Copy(w, &c.recorder)
|
||||
debug("sendStreamCopier.WriteStreamTo: copy done")
|
||||
if err != nil {
|
||||
if c.recorder.readErr != nil {
|
||||
return sendStreamCopierError{isReadErr: true, err: c.recorder.readErr}
|
||||
} else {
|
||||
return sendStreamCopierError{isReadErr: false, err: err}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ReadCloserCopier) Read(p []byte) (n int, err error) {
|
||||
return c.recorder.Read(p)
|
||||
}
|
||||
|
||||
func (c *ReadCloserCopier) Close() error {
|
||||
return c.recorder.ReadCloser.Close()
|
||||
}
|
||||
|
||||
func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
|
||||
if capacity <= 0 {
|
||||
panic(fmt.Sprintf("capacity must be positive %v", capacity))
|
||||
@ -423,7 +366,7 @@ func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
|
||||
return stdoutReader, stdoutWriter, nil
|
||||
}
|
||||
|
||||
type sendStream struct {
|
||||
type SendStream struct {
|
||||
cmd *zfscmd.Cmd
|
||||
kill context.CancelFunc
|
||||
|
||||
@ -433,7 +376,7 @@ type sendStream struct {
|
||||
opErr error
|
||||
}
|
||||
|
||||
func (s *sendStream) Read(p []byte) (n int, err error) {
|
||||
func (s *SendStream) Read(p []byte) (n int, err error) {
|
||||
s.closeMtx.Lock()
|
||||
opErr := s.opErr
|
||||
s.closeMtx.Unlock()
|
||||
@ -454,12 +397,12 @@ func (s *sendStream) Read(p []byte) (n int, err error) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *sendStream) Close() error {
|
||||
func (s *SendStream) Close() error {
|
||||
debug("sendStream: close called")
|
||||
return s.killAndWait(nil)
|
||||
}
|
||||
|
||||
func (s *sendStream) killAndWait(precedingReadErr error) error {
|
||||
func (s *SendStream) killAndWait(precedingReadErr error) error {
|
||||
|
||||
debug("sendStream: killAndWait enter")
|
||||
defer debug("sendStream: killAndWait leave")
|
||||
@ -830,7 +773,7 @@ var ErrEncryptedSendNotSupported = fmt.Errorf("raw sends which are required for
|
||||
// (if from is "" a full ZFS send is done)
|
||||
//
|
||||
// Returns ErrEncryptedSendNotSupported if encrypted send is requested but not supported by CLI
|
||||
func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*ReadCloserCopier, error) {
|
||||
func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*SendStream, error) {
|
||||
|
||||
args := make([]string, 0)
|
||||
args = append(args, "send")
|
||||
@ -879,14 +822,14 @@ func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*ReadCloserCop
|
||||
// close our writing-end of the pipe so that we don't wait for ourselves when reading from the reading end
|
||||
stdoutWriter.Close()
|
||||
|
||||
stream := &sendStream{
|
||||
stream := &SendStream{
|
||||
cmd: cmd,
|
||||
kill: cancel,
|
||||
stdoutReader: stdoutReader,
|
||||
stderrBuf: stderrBuf,
|
||||
}
|
||||
|
||||
return NewReadCloserCopier(stream), nil
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
type DrySendType string
|
||||
@ -1025,24 +968,6 @@ func ZFSSendDry(ctx context.Context, sendArgs ZFSSendArgsValidated) (_ *DrySendI
|
||||
return &si, nil
|
||||
}
|
||||
|
||||
type StreamCopierError interface {
|
||||
error
|
||||
IsReadError() bool
|
||||
IsWriteError() bool
|
||||
}
|
||||
|
||||
type StreamCopier interface {
|
||||
// WriteStreamTo writes the stream represented by this StreamCopier
|
||||
// to the given io.Writer.
|
||||
WriteStreamTo(w io.Writer) StreamCopierError
|
||||
// Close must be called as soon as it is clear that no more data will
|
||||
// be read from the StreamCopier.
|
||||
// If StreamCopier gets its data from a connection, it might hold
|
||||
// a lock on the connection until Close is called. Only closing ensures
|
||||
// that the connection can be used afterwards.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type RecvOptions struct {
|
||||
// Rollback to the oldest snapshot, destroy it, then perform `recv -F`.
|
||||
// Note that this doesn't change property values, i.e. an existing local property value will be kept.
|
||||
@ -1067,7 +992,9 @@ func (e *ErrRecvResumeNotSupported) Error() string {
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier StreamCopier, opts RecvOptions) (err error) {
|
||||
const RecvStderrBufSiz = 1 << 15
|
||||
|
||||
func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, stream io.ReadCloser, opts RecvOptions) (err error) {
|
||||
|
||||
if err := v.ValidateInMemory(fs); err != nil {
|
||||
return errors.Wrap(err, "invalid version")
|
||||
@ -1134,7 +1061,7 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
|
||||
// cannot receive new filesystem stream: invalid backup stream
|
||||
stdout := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||
|
||||
stderr := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||
stderr := bytes.NewBuffer(make([]byte, 0, RecvStderrBufSiz))
|
||||
|
||||
stdin, stdinWriter, err := pipeWithCapacityHint(ZFSRecvPipeCapacityHint)
|
||||
if err != nil {
|
||||
@ -1162,9 +1089,10 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
|
||||
|
||||
debug("started")
|
||||
|
||||
copierErrChan := make(chan StreamCopierError)
|
||||
copierErrChan := make(chan error)
|
||||
go func() {
|
||||
copierErrChan <- streamCopier.WriteStreamTo(stdinWriter)
|
||||
_, err := io.Copy(stdinWriter, stream)
|
||||
copierErrChan <- err
|
||||
stdinWriter.Close()
|
||||
}()
|
||||
waitErrChan := make(chan error)
|
||||
@ -1173,6 +1101,10 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
|
||||
if err = cmd.Wait(); err != nil {
|
||||
if rtErr := tryRecvErrorWithResumeToken(ctx, stderr.String()); rtErr != nil {
|
||||
waitErrChan <- rtErr
|
||||
} else if owErr := tryRecvDestroyOrOverwriteEncryptedErr(stderr.Bytes()); owErr != nil {
|
||||
waitErrChan <- owErr
|
||||
} else if readErr := tryRecvCannotReadFromStreamErr(stderr.Bytes()); readErr != nil {
|
||||
waitErrChan <- readErr
|
||||
} else {
|
||||
waitErrChan <- &ZFSError{
|
||||
Stderr: stderr.Bytes(),
|
||||
@ -1183,22 +1115,23 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
|
||||
}
|
||||
}()
|
||||
|
||||
// streamCopier always fails before or simultaneously with Wait
|
||||
// thus receive from it first
|
||||
copierErr := <-copierErrChan
|
||||
debug("copierErr: %T %s", copierErr, copierErr)
|
||||
if copierErr != nil {
|
||||
debug("killing zfs recv command after copierErr")
|
||||
cancelCmd()
|
||||
}
|
||||
|
||||
waitErr := <-waitErrChan
|
||||
debug("waitErr: %T %s", waitErr, waitErr)
|
||||
|
||||
if copierErr == nil && waitErr == nil {
|
||||
return nil
|
||||
} else if waitErr != nil && (copierErr == nil || copierErr.IsWriteError()) {
|
||||
return waitErr // has more interesting info in that case
|
||||
} else if _, isReadErr := waitErr.(*RecvCannotReadFromStreamErr); isReadErr {
|
||||
return copierErr // likely network error reading from stream
|
||||
} else {
|
||||
return waitErr // almost always more interesting info. NOTE: do not wrap!
|
||||
}
|
||||
return copierErr // if it's not a write error, the copier error is more interesting
|
||||
}
|
||||
|
||||
type RecvFailedWithResumeTokenErr struct {
|
||||
@ -1228,6 +1161,43 @@ func (e *RecvFailedWithResumeTokenErr) Error() string {
|
||||
return fmt.Sprintf("receive failed, resume token available: %s\n%#v", e.ResumeTokenRaw, e.ResumeTokenParsed)
|
||||
}
|
||||
|
||||
type RecvDestroyOrOverwriteEncryptedErr struct {
|
||||
Msg string
|
||||
}
|
||||
|
||||
func (e *RecvDestroyOrOverwriteEncryptedErr) Error() string {
|
||||
return e.Msg
|
||||
}
|
||||
|
||||
var recvDestroyOrOverwriteEncryptedErrRe = regexp.MustCompile(`^(cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one)`)
|
||||
|
||||
func tryRecvDestroyOrOverwriteEncryptedErr(stderr []byte) *RecvDestroyOrOverwriteEncryptedErr {
|
||||
debug("tryRecvDestroyOrOverwriteEncryptedErr: %v", stderr)
|
||||
m := recvDestroyOrOverwriteEncryptedErrRe.FindSubmatch(stderr)
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &RecvDestroyOrOverwriteEncryptedErr{Msg: string(m[1])}
|
||||
}
|
||||
|
||||
type RecvCannotReadFromStreamErr struct {
|
||||
Msg string
|
||||
}
|
||||
|
||||
func (e *RecvCannotReadFromStreamErr) Error() string {
|
||||
return e.Msg
|
||||
}
|
||||
|
||||
var reRecvCannotReadFromStreamErr = regexp.MustCompile(`^(cannot receive: failed to read from stream)$`)
|
||||
|
||||
func tryRecvCannotReadFromStreamErr(stderr []byte) *RecvCannotReadFromStreamErr {
|
||||
m := reRecvCannotReadFromStreamErr.FindSubmatch(stderr)
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &RecvCannotReadFromStreamErr{Msg: string(m[1])}
|
||||
}
|
||||
|
||||
type ClearResumeTokenError struct {
|
||||
ZFSOutput []byte
|
||||
CmdError error
|
||||
|
@ -2,9 +2,11 @@ package zfs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestZFSListHandlesProducesZFSErrorOnNonZeroExit(t *testing.T) {
|
||||
@ -259,3 +261,12 @@ size 10518512
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryRecvDestroyOrOverwriteEncryptedErr(t *testing.T) {
|
||||
msg := "cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one\n"
|
||||
assert.GreaterOrEqual(t, RecvStderrBufSiz, len(msg))
|
||||
|
||||
err := tryRecvDestroyOrOverwriteEncryptedErr([]byte(msg))
|
||||
require.NotNil(t, err)
|
||||
assert.EqualError(t, err, strings.TrimSpace(msg))
|
||||
}
|
||||
|
@ -5,11 +5,14 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zrepl/zrepl/util/circlog"
|
||||
)
|
||||
|
||||
const testBin = "./zfscmd_platform_test.bash"
|
||||
@ -85,5 +88,37 @@ func TestCmdProcessState(t *testing.T) {
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, ee.ProcessState)
|
||||
require.Contains(t, ee.Error(), "killed")
|
||||
}
|
||||
|
||||
func TestSigpipe(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
cmd := CommandContext(ctx, "bash", "-c", "sleep 5; echo invalid input; exit 23")
|
||||
r, w, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
output := circlog.MustNewCircularLog(1 << 20)
|
||||
cmd.SetStdio(Stdio{
|
||||
Stdin: r,
|
||||
Stdout: output,
|
||||
Stderr: output,
|
||||
})
|
||||
err = cmd.Start()
|
||||
require.NoError(t, err)
|
||||
err = r.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// the script doesn't read stdin, but this input is almost certainly smaller than the pipe buffer
|
||||
const LargerThanPipeBuffer = 1 << 21
|
||||
_, err = io.Copy(w, bytes.NewBuffer(bytes.Repeat([]byte("i"), LargerThanPipeBuffer)))
|
||||
// => io.Copy is going to block because the pipe buffer is full and the
|
||||
// script is not reading from it
|
||||
// => the script is going to exit after 5s
|
||||
// => we should expect a broken pipe error from the copier's perspective
|
||||
t.Logf("copy err = %T: %s", err, err)
|
||||
require.NotNil(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "broken pipe"))
|
||||
|
||||
err = cmd.Wait()
|
||||
require.EqualError(t, err, "exit status 23")
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user