diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go
index 4bd769b..0003e82 100644
--- a/endpoint/endpoint.go
+++ b/endpoint/endpoint.go
@@ -4,6 +4,7 @@ package endpoint
 import (
 	"context"
 	"fmt"
+	"io"
 	"path"
 	"sync"
 
@@ -241,7 +242,7 @@ func sendArgsFromPDUAndValidateExistsAndGetVersion(ctx context.Context, fs strin
 	return version, nil
 }
 
-func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
+func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
 
 	_, err := s.filterCheckFS(r.Filesystem)
 	if err != nil {
@@ -339,11 +340,11 @@ func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.St
 
 	// step holds & replication cursor released / moved forward in s.SendCompleted => s.moveCursorAndReleaseSendHolds
 
-	streamCopier, err := zfs.ZFSSend(ctx, sendArgs)
+	sendStream, err := zfs.ZFSSend(ctx, sendArgs)
 	if err != nil {
 		return nil, nil, errors.Wrap(err, "zfs send failed")
 	}
-	return res, streamCopier, nil
+	return res, sendStream, nil
 }
 
 func (p *Sender) SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error) {
@@ -476,7 +477,7 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs
 	return &pdu.ReplicationCursorRes{Result: &pdu.ReplicationCursorRes_Guid{Guid: cursor.Guid}}, nil
 }
 
-func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
+func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, _ io.ReadCloser) (*pdu.ReceiveRes, error) {
 	return nil, fmt.Errorf("sender does not implement Receive()")
 }
 
@@ -680,13 +681,13 @@ func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq)
 	return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver")
 }
 
-func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
+func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
 	return nil, nil, fmt.Errorf("receiver does not implement Send()")
 }
 
 var maxConcurrentZFSRecvSemaphore = semaphore.New(envconst.Int64("ZREPL_ENDPOINT_MAX_CONCURRENT_RECV", 10))
 
-func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) {
+func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error) {
 	getLogger(ctx).Debug("incoming Receive")
 	defer receive.Close()
 
diff --git a/platformtest/tests/helpers.go b/platformtest/tests/helpers.go
index bb7f516..e1394b0 100644
--- a/platformtest/tests/helpers.go
+++ b/platformtest/tests/helpers.go
@@ -136,7 +136,7 @@ func makeResumeSituation(ctx *platformtest.Context, src dummySnapshotSituation,
 		return situation
 	}
 
-	limitedCopier := zfs.NewReadCloserCopier(limitio.ReadCloser(copier, src.dummyDataLen/2))
+	limitedCopier := limitio.ReadCloser(copier, src.dummyDataLen/2)
 	defer limitedCopier.Close()
 
 	require.NotNil(ctx, sendArgs.To)
diff --git a/platformtest/tests/sendArgsValidation.go b/platformtest/tests/sendArgsValidation.go
index 91ee6d9..416c046 100644
--- a/platformtest/tests/sendArgsValidation.go
+++ b/platformtest/tests/sendArgsValidation.go
@@ -35,7 +35,7 @@ func SendArgsValidationEncryptedSendOfUnencryptedDatasetForbidden(ctx *platformt
 		ResumeToken: "",
 	}.Validate(ctx)
 
-	var stream *zfs.ReadCloserCopier
+	var stream *zfs.SendStream
 	if err == nil {
 		stream, err = zfs.ZFSSend(ctx, sendArgs) // no shadow
 		if err == nil {
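The next hunks change the replication interfaces themselves: Sender.Send now hands back a plain io.ReadCloser, which is nil for dry runs. For orientation, here is a stdlib-only sketch of that contract and the caller-side obligations; every name in it is invented for illustration and is not part of the patch:

```go
package main

import (
	"context"
	"fmt"
	"io"
	"strings"
)

// send models the new endpoint contract: a plain io.ReadCloser instead of
// the former zfs.StreamCopier; nil means "dry run, no stream attached".
func send(_ context.Context, dryRun bool) (io.ReadCloser, error) {
	if dryRun {
		return nil, nil
	}
	return io.NopCloser(strings.NewReader("zfs send stream bytes")), nil
}

func main() {
	stream, err := send(context.Background(), false)
	if err != nil || stream == nil {
		return // dry run: only the size estimate in the response matters
	}
	defer stream.Close() // callers must still close, as with StreamCopier
	n, _ := io.Copy(io.Discard, stream)
	fmt.Println("copied", n, "bytes")
}
```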
diff --git a/replication/logic/replication_logic.go b/replication/logic/replication_logic.go
index 6de7a8a..cf56990 100644
--- a/replication/logic/replication_logic.go
+++ b/replication/logic/replication_logic.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"sync"
 	"time"
 
@@ -38,7 +39,7 @@ type Sender interface {
 	// If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before
 	// any next call to the parent github.com/zrepl/zrepl/replication.Endpoint.
 	// If the send request is for a dry run, the io.ReadCloser will be nil.
-	Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
+	Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
 	SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error)
 	ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
 }
@@ -47,7 +48,7 @@ type Receiver interface {
 	Endpoint
 	// Receive sends r and sendStream (the latter containing a ZFS send stream)
 	// to the parent github.com/zrepl/zrepl/replication.Endpoint.
-	Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
+	Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
 }
 
 type PlannerPolicy struct {
@@ -162,7 +163,7 @@ type Step struct {
 
 	// byteCounter is nil initially, and set later in Step.doReplication
 	// => concurrent read of that pointer from Step.ReportInfo must be protected
-	byteCounter bytecounter.StreamCopier
+	byteCounter bytecounter.ReadCloser
 	byteCounterMtx chainlock.L
 }
 
@@ -606,19 +607,19 @@ func (s *Step) doReplication(ctx context.Context) error {
 	sr := s.buildSendRequest(false)
 
 	log.Debug("initiate send request")
-	sres, sstreamCopier, err := s.sender.Send(ctx, sr)
+	sres, stream, err := s.sender.Send(ctx, sr)
 	if err != nil {
 		log.WithError(err).Error("send request failed")
 		return err
 	}
-	if sstreamCopier == nil {
+	if stream == nil {
 		err := errors.New("send request did not return a stream, broken endpoint implementation")
 		return err
 	}
-	defer sstreamCopier.Close()
+	defer stream.Close()
 
 	// Install a byte counter to track progress + for status report
-	byteCountingStream := bytecounter.NewStreamCopier(sstreamCopier)
+	byteCountingStream := bytecounter.NewReadCloser(stream)
 	s.byteCounterMtx.Lock()
 	s.byteCounter = byteCountingStream
 	s.byteCounterMtx.Unlock()
diff --git a/rpc/dataconn/dataconn_client.go b/rpc/dataconn/dataconn_client.go
index 6cc049b..4f0cb43 100644
--- a/rpc/dataconn/dataconn_client.go
+++ b/rpc/dataconn/dataconn_client.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"io"
 	"strings"
 
 	"github.com/golang/protobuf/proto"
@@ -11,7 +12,6 @@ import (
 	"github.com/zrepl/zrepl/replication/logic/pdu"
 	"github.com/zrepl/zrepl/rpc/dataconn/stream"
 	"github.com/zrepl/zrepl/transport"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 type Client struct {
@@ -26,7 +26,7 @@ func NewClient(connecter transport.Connecter, log Logger) *Client {
 	}
 }
 
-func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error {
+func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, stream io.ReadCloser) error {
 
 	var buf bytes.Buffer
 	_, memErr := buf.WriteString(endpoint)
@@ -46,8 +46,8 @@ func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, r
 		return err
 	}
 
-	if streamCopier != nil {
-		return conn.SendStream(ctx, streamCopier, ZFSStream)
+	if stream != nil {
+		return conn.SendStream(ctx, stream, ZFSStream)
 	} else {
 		return nil
 	}
@@ -109,7 +109,7 @@ func (c *Client) putWire(conn *stream.Conn) {
 	}
 }
 
-func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
+func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
 	conn, err := c.getWire(ctx)
 	if err != nil {
 		return nil, nil, err
@@ -130,16 +130,19 @@ func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, z
 		return nil, nil, err
 	}
 
-	var copier zfs.StreamCopier = nil
+	var stream io.ReadCloser
 	if !req.DryRun {
 		putWireOnReturn = false
-		copier = &streamCopier{streamConn: conn, closeStreamOnClose: true}
+		stream, err = conn.ReadStream(ZFSStream, true) // no shadow
+		if err != nil {
+			return nil, nil, err
+		}
 	}
 
-	return &res, copier, nil
+	return &res, stream, nil
 }
 
-func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
+func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
 	defer c.log.Debug("ReqRecv returns")
 
 	conn, err := c.getWire(ctx)
@@ -166,7 +169,7 @@ func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier
 
 	sendErrChan := make(chan error)
 	go func() {
-		if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil {
+		if err := c.send(ctx, conn, EndpointRecv, req, stream); err != nil {
 			sendErrChan <- err
 		} else {
 			sendErrChan <- nil
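Note the ownership asymmetry that carries over from the deleted streamCopier: the client calls conn.ReadStream(ZFSStream, true) because the returned reader escapes to the caller and closing it must tear down the whole connection, while the server below passes false so the connection stays usable for sending the response. A toy model of that split, with all names invented (the real StreamReader in rpc/dataconn/stream/stream_conn.go additionally closes its pipe):

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// conn stands in for *stream.Conn.
type conn struct{ closed bool }

func (c *conn) Close() error { c.closed = true; return nil }

// streamReader mirrors StreamReader's closeConnOnClose semantics.
type streamReader struct {
	io.Reader
	conn             *conn
	closeConnOnClose bool
}

func (r *streamReader) Close() error {
	if r.closeConnOnClose {
		return r.conn.Close() // reader owns the connection (client side)
	}
	return nil // connection survives for the response (server side)
}

func main() {
	c := &conn{}
	var rc io.ReadCloser = &streamReader{strings.NewReader("data"), c, true}
	rc.Close()
	fmt.Println("conn closed:", c.closed)
}
```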
diff --git a/rpc/dataconn/dataconn_server.go b/rpc/dataconn/dataconn_server.go
index fbd1a95..aecaf96 100644
--- a/rpc/dataconn/dataconn_server.go
+++ b/rpc/dataconn/dataconn_server.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"io"
 
 	"github.com/golang/protobuf/proto"
 
@@ -11,7 +12,6 @@ import (
 	"github.com/zrepl/zrepl/replication/logic/pdu"
 	"github.com/zrepl/zrepl/rpc/dataconn/stream"
 	"github.com/zrepl/zrepl/transport"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 // WireInterceptor has a chance to exchange the context and connection on each client connection.
@@ -21,11 +21,11 @@ type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (con
 type Handler interface {
 	// Send handles a SendRequest.
 	// The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run.
-	Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error)
+	Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
 	// Receive handles a ReceiveRequest.
 	// It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout
 	// configured in ServerConfig.Shared.IdleConnTimeout.
-	Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error)
+	Receive(ctx context.Context, r *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
 	// PingDataconn handles a PingReq
 	PingDataconn(ctx context.Context, r *pdu.PingReq) (*pdu.PingRes, error)
 }
@@ -111,7 +111,7 @@ func (s *Server) serveConn(nc *transport.AuthConn) {
 	s.log.WithField("endpoint", endpoint).Debug("calling handler")
 
 	var res proto.Message
-	var sendStream zfs.StreamCopier
+	var sendStream io.ReadCloser
 	var handlerErr error
 	switch endpoint {
 	case EndpointSend:
@@ -127,7 +127,12 @@ func (s *Server) serveConn(nc *transport.AuthConn) {
 			s.log.WithError(err).Error("cannot unmarshal receive request")
 			return
 		}
-		res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING
+		stream, err := c.ReadStream(ZFSStream, false)
+		if err != nil {
+			s.log.WithError(err).Error("cannot open stream in receive request")
+			return
+		}
+		res, handlerErr = s.h.Receive(ctx, &req, stream) // SHADOWING
 	case EndpointPing:
 		var req pdu.PingReq
 		if err := proto.Unmarshal(reqStructured, &req); err != nil {
diff --git a/rpc/dataconn/dataconn_shared.go b/rpc/dataconn/dataconn_shared.go
index 96e2fd8..ab5435f 100644
--- a/rpc/dataconn/dataconn_shared.go
+++ b/rpc/dataconn/dataconn_shared.go
@@ -1,12 +1,7 @@
 package dataconn
 
 import (
-	"io"
-	"sync"
 	"time"
-
-	"github.com/zrepl/zrepl/rpc/dataconn/stream"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 const (
@@ -39,33 +34,3 @@ const (
 	responseHeaderHandlerOk          = "HANDLER OK\n"
 	responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n"
 )
-
-type streamCopier struct {
-	mtx                sync.Mutex
-	used               bool
-	streamConn         *stream.Conn
-	closeStreamOnClose bool
-}
-
-// WriteStreamTo implements zfs.StreamCopier
-func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
-	s.mtx.Lock()
-	defer s.mtx.Unlock()
-	if s.used {
-		panic("streamCopier used multiple times")
-	}
-	s.used = true
-	return s.streamConn.ReadStreamInto(w, ZFSStream)
-}
-
-// Close implements zfs.StreamCopier
-func (s *streamCopier) Close() error {
-	// only record the close here, what we do actually depends on whether
-	// the streamCopier is instantiated server-side or client-side
-	s.mtx.Lock()
-	defer s.mtx.Unlock()
-	if s.closeStreamOnClose {
-		return s.streamConn.Close()
-	}
-	return nil
-}
diff --git a/rpc/dataconn/microbenchmark/microbenchmark.go b/rpc/dataconn/microbenchmark/microbenchmark.go
index 6a762ed..d8708c6 100644
--- a/rpc/dataconn/microbenchmark/microbenchmark.go
+++ b/rpc/dataconn/microbenchmark/microbenchmark.go
@@ -29,7 +29,6 @@ import (
 	"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
 	"github.com/zrepl/zrepl/transport"
 	"github.com/zrepl/zrepl/util/devnoop"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 func orDie(err error) {
@@ -42,23 +41,9 @@ type readerStreamCopier struct{ io.Reader }
 
 func (readerStreamCopier) Close() error { return nil }
 
-type readerStreamCopierErr struct {
-	error
-}
-
-func (readerStreamCopierErr) IsReadError() bool  { return false }
-func (readerStreamCopierErr) IsWriteError() bool { return true }
-
-func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
-	var buf [1 << 21]byte
-	_, err := io.CopyBuffer(w, c.Reader, buf[:])
-	// always assume write error
-	return readerStreamCopierErr{err}
-}
-
 type devNullHandler struct{}
 
-func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
+func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
 	var res pdu.SendRes
 	if args.devnoopReader {
 		return &res, readerStreamCopier{devnoop.Get()}, nil
 	}
@@ -67,12 +52,12 @@ func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, z
 	}
 }
 
-func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) {
+func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
 	var out io.Writer = os.Stdout
 	if args.devnoopWriter {
 		out = devnoop.Get()
 	}
-	err := stream.WriteStreamTo(out)
+	_, err := io.Copy(out, stream)
 	var res pdu.ReceiveRes
 	return &res, err
 }
@@ -172,7 +157,7 @@ func client() {
 		req := pdu.SendReq{}
 		_, stream, err := client.ReqSend(ctx, &req)
 		orDie(err)
-		err = stream.WriteStreamTo(os.Stdout)
+		_, err = io.Copy(os.Stdout, stream)
 		orDie(err)
 	case "recv":
 		var r io.Reader = os.Stdin
diff --git a/rpc/dataconn/stream/stream.go b/rpc/dataconn/stream/stream.go
index b46d000..b12e854 100644
--- a/rpc/dataconn/stream/stream.go
+++ b/rpc/dataconn/stream/stream.go
@@ -14,7 +14,6 @@ import (
 	"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
 	"github.com/zrepl/zrepl/rpc/dataconn/frameconn"
 	"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 type Logger = logger.Logger
@@ -198,7 +197,7 @@ func (e ReadStreamError) Temporary() bool {
 	return false
 }
 
-var _ zfs.StreamCopierError = &ReadStreamError{}
+var _ net.Error = &ReadStreamError{}
 
 func (e ReadStreamError) IsReadError() bool {
 	return e.Kind != ReadStreamErrorKindWrite
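The `var _ net.Error = &ReadStreamError{}` assertion above replaces the deleted zfs.StreamCopierError assertion with a stdlib interface. The idiom is a compile-time check: if the type ever stops implementing the interface, the build fails rather than a runtime type assertion. A minimal illustration with an invented type:

```go
package main

import "net"

// toyErr mimics how ReadStreamError satisfies net.Error.
type toyErr struct{ msg string }

func (e *toyErr) Error() string   { return e.msg }
func (e *toyErr) Timeout() bool   { return false }
func (e *toyErr) Temporary() bool { return false }

// Compile-time check: deleting any method above breaks the build.
var _ net.Error = (*toyErr)(nil)

func main() {}
```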
diff --git a/rpc/dataconn/stream/stream_conn.go b/rpc/dataconn/stream/stream_conn.go
index a7c7879..1ab05b0 100644
--- a/rpc/dataconn/stream/stream_conn.go
+++ b/rpc/dataconn/stream/stream_conn.go
@@ -14,7 +14,6 @@ import (
 
 	"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
 	"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 type Conn struct {
@@ -40,15 +39,7 @@ type Conn struct {
 
 var readMessageSentinel = fmt.Errorf("read stream complete")
 
-type writeStreamToErrorUnknownState struct{}
-
-func (e writeStreamToErrorUnknownState) Error() string {
-	return "dataconn read stream: connection is in unknown state"
-}
-
-func (e writeStreamToErrorUnknownState) IsReadError() bool { return true }
-
-func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false }
+var errWriteStreamToErrorUnknownState = fmt.Errorf("dataconn read stream: connection is in unknown state")
 
 func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn {
 	hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout)
@@ -123,14 +114,28 @@ func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameTyp
 	}
 }
 
+type StreamReader struct {
+	*io.PipeReader
+	conn             *Conn
+	closeConnOnClose bool
+}
+
+func (r *StreamReader) Close() error {
+	err := r.PipeReader.Close()
+	if r.closeConnOnClose {
+		r.conn.Close() // TODO error logging
+	}
+	return err
+}
+
-// WriteStreamTo reads a stream from Conn and writes it to w.
-func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopierError) {
+// ReadStream reads a stream from Conn and returns an io.ReadCloser for it.
+func (c *Conn) ReadStream(frameType uint32, closeConnOnClose bool) (_ *StreamReader, err error) {
 
 	// if we are closed while writing, return that as an error
 	if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
-		return cse
+		return nil, cse
 	} else {
-		defer func(err *zfs.StreamCopierError) {
+		defer func(err *error) {
 			if closed := closeGuard.RWExit(); closed != nil {
 				*err = closed
 			}
@@ -138,18 +143,23 @@ func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopi
 	}
 
 	c.readMtx.Lock()
-	defer c.readMtx.Unlock()
 	if !c.readClean {
-		return writeStreamToErrorUnknownState{}
+		return nil, errWriteStreamToErrorUnknownState
 	}
-	var rse *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
-	c.readClean = isConnCleanAfterRead(rse)
-
-	// https://golang.org/doc/faq#nil_error
-	if rse == nil {
-		return nil
-	}
-	return rse
+
+	r, w := io.Pipe()
+	go func() {
+		defer c.readMtx.Unlock()
+		var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
+		if err != nil {
+			_ = w.CloseWithError(err) // doc guarantees that error will always be nil
+		} else {
+			w.Close()
+		}
+		c.readClean = isConnCleanAfterRead(err)
+	}()
+
+	return &StreamReader{PipeReader: r, conn: c, closeConnOnClose: closeConnOnClose}, nil
 }
 
 func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) (err error) {
@@ -178,7 +188,7 @@ func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameTyp
 	return errConn
 }
 
-func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) (err error) {
+func (c *Conn) SendStream(ctx context.Context, stream io.ReadCloser, frameType uint32) (err error) {
 
 	// if we are closed while reading, return that as an error
 	if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
@@ -197,49 +207,17 @@ func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType u
 		return fmt.Errorf("dataconn send stream: connection is in unknown state")
 	}
 
-	// avoid io.Pipe if zfs.StreamCopier is an io.Reader
-	var r io.Reader
-	var w *io.PipeWriter
-	streamCopierErrChan := make(chan zfs.StreamCopierError, 1)
-	if reader, ok := src.(io.Reader); ok {
-		r = reader
-		streamCopierErrChan <- nil
-		close(streamCopierErrChan)
-	} else {
-		r, w = io.Pipe()
-		go func() {
-			streamCopierErrChan <- src.WriteStreamTo(w)
-			w.Close()
-		}()
-	}
+	errStream, errConn := writeStream(ctx, c.hc, stream, frameType)
 
-	type writeStreamRes struct {
-		errStream, errConn error
-	}
-	writeStreamErrChan := make(chan writeStreamRes, 1)
-	go func() {
-		var res writeStreamRes
-		res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType)
-		if w != nil {
-			_ = w.CloseWithError(res.errStream) // always returns nil
-		}
-		writeStreamErrChan <- res
-	}()
+	c.writeClean = isConnCleanAfterWrite(errConn) // TODO correct?
 
-	writeRes := <-writeStreamErrChan
-	streamCopierErr := <-streamCopierErrChan
-	c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct?
-	if streamCopierErr != nil && streamCopierErr.IsReadError() {
-		return streamCopierErr // something on our side is bad
-	} else {
-		if writeRes.errStream != nil {
-			return writeRes.errStream
-		} else if writeRes.errConn != nil {
-			return writeRes.errConn
-		}
-		// TODO combined error?
-		return streamCopierErr
+	if errStream != nil {
+		return errStream
+	} else if errConn != nil {
+		return errConn
 	}
+	// TODO combined error?
+	return nil
 }
 
 type closeState struct {
 	closeCount uint32
 }
 
 type closeStateErrConnectionClosed struct{}
 
-var _ zfs.StreamCopierError = (*closeStateErrConnectionClosed)(nil)
-var _ error = (*closeStateErrConnectionClosed)(nil)
 var _ net.Error = (*closeStateErrConnectionClosed)(nil)
 
 func (e *closeStateErrConnectionClosed) Error() string {
 	return "connection closed"
 }
-func (e *closeStateErrConnectionClosed) IsReadError() bool  { return true }
-func (e *closeStateErrConnectionClosed) IsWriteError() bool { return true }
-func (e *closeStateErrConnectionClosed) Timeout() bool      { return false }
-func (e *closeStateErrConnectionClosed) Temporary() bool    { return false }
+func (e *closeStateErrConnectionClosed) Timeout() bool   { return false }
+func (e *closeStateErrConnectionClosed) Temporary() bool { return false }
 
 func (s *closeState) CloseEntry() error {
 	firstCloser := atomic.AddUint32(&s.closeCount, 1) == 1
@@ -273,7 +247,7 @@ type closeStateEntry struct {
 	entryCount uint32
 }
 
-func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) {
+func (s *closeState) RWEntry() (e *closeStateEntry, err net.Error) {
 	entry := &closeStateEntry{s, atomic.LoadUint32(&s.closeCount)}
 	if entry.entryCount > 0 {
 		return nil, &closeStateErrConnectionClosed{}
@@ -281,7 +255,7 @@ func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) {
 	return entry, nil
 }
 
-func (e *closeStateEntry) RWExit() zfs.StreamCopierError {
+func (e *closeStateEntry) RWExit() net.Error {
 	if atomic.LoadUint32(&e.entryCount) == e.entryCount {
 		// no calls to Close() while running rw operation
 		return nil
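ReadStream above adapts the writer-driven readStream (which pushes frames into a writer) to the pull-style io.ReadCloser the new interfaces expect, via io.Pipe and a goroutine. The adaptation in isolation, with producer and payload as stand-ins rather than zrepl code:

```go
package main

import (
	"fmt"
	"io"
)

// produce stands in for readStream: it pushes data into a writer
// and reports a possible error when done.
func produce(w io.Writer) error {
	_, err := io.WriteString(w, "frame 1, frame 2, frame 3")
	return err
}

func main() {
	r, w := io.Pipe()
	go func() {
		// CloseWithError(nil) behaves like Close: the reader sees io.EOF.
		w.CloseWithError(produce(w))
	}()
	out, err := io.ReadAll(r) // consumers only ever see an io.Reader
	fmt.Printf("%q %v\n", out, err)
}
```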
diff --git a/rpc/rpc_client.go b/rpc/rpc_client.go
index bdc81de..5d9d2c8 100644
--- a/rpc/rpc_client.go
+++ b/rpc/rpc_client.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"io"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -20,7 +21,6 @@ import (
 	"github.com/zrepl/zrepl/rpc/versionhandshake"
 	"github.com/zrepl/zrepl/transport"
 	"github.com/zrepl/zrepl/util/envconst"
-	"github.com/zrepl/zrepl/zfs"
 )
 
 // Client implements the active side of a replication setup.
@@ -82,22 +82,22 @@ func (c *Client) Close() {
 
 // callers must ensure that the returned io.ReadCloser is closed
 // TODO expose dataClient interface to the outside world
-func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) {
+func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
 	// TODO the returned sendStream may return a read error created by the remote side
-	res, streamCopier, err := c.dataClient.ReqSend(ctx, r)
+	res, stream, err := c.dataClient.ReqSend(ctx, r)
 	if err != nil {
 		return nil, nil, err
 	}
-	if streamCopier == nil {
+	if stream == nil {
 		return res, nil, nil
 	}
 
-	return res, streamCopier, nil
+	return res, stream, nil
 }
 
-func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) {
-	return c.dataClient.ReqRecv(ctx, req, streamCopier)
+func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
+	return c.dataClient.ReqRecv(ctx, req, stream)
 }
 
 func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {
diff --git a/util/bytecounter/bytecounter_readcloser.go b/util/bytecounter/bytecounter_readcloser.go
new file mode 100644
index 0000000..0ddd238
--- /dev/null
+++ b/util/bytecounter/bytecounter_readcloser.go
@@ -0,0 +1,39 @@
+package bytecounter
+
+import (
+	"io"
+	"sync/atomic"
+)
+
+// ReadCloser wraps an io.ReadCloser, reimplementing
+// its interface and counting the bytes read through it.
+type ReadCloser interface {
+	io.ReadCloser
+	Count() int64
+}
+
+// NewReadCloser wraps rc.
+func NewReadCloser(rc io.ReadCloser) ReadCloser {
+	return &readCloser{rc, 0}
+}
+
+type readCloser struct {
+	rc    io.ReadCloser
+	count int64
+}
+
+func (r *readCloser) Count() int64 {
+	return atomic.LoadInt64(&r.count)
+}
+
+var _ io.ReadCloser = &readCloser{}
+
+func (r *readCloser) Close() error {
+	return r.rc.Close()
+}
+
+func (r *readCloser) Read(p []byte) (int, error) {
+	n, err := r.rc.Read(p)
+	atomic.AddInt64(&r.count, int64(n))
+	return n, err
+}
diff --git a/util/bytecounter/bytecounter_reader.go b/util/bytecounter/bytecounter_reader.go
deleted file mode 100644
index fc1f7c5..0000000
--- a/util/bytecounter/bytecounter_reader.go
+++ /dev/null
@@ -1,49 +0,0 @@
-package bytecounter
-
-import (
-	"io"
-	"sync/atomic"
-	"time"
-)
-
-type ByteCounterReader struct {
-	reader io.ReadCloser
-
-	// called & accessed synchronously during Read, no external access
-	cb       func(full int64)
-	cbEvery  time.Duration
-	lastCbAt time.Time
-
-	// set atomically because it may be read by multiple threads
-	bytes int64
-}
-
-func NewByteCounterReader(reader io.ReadCloser) *ByteCounterReader {
-	return &ByteCounterReader{
-		reader: reader,
-	}
-}
-
-func (b *ByteCounterReader) SetCallback(every time.Duration, cb func(full int64)) {
-	b.cbEvery = every
-	b.cb = cb
-}
-
-func (b *ByteCounterReader) Close() error {
-	return b.reader.Close()
-}
-
-func (b *ByteCounterReader) Read(p []byte) (n int, err error) {
-	n, err = b.reader.Read(p)
-	full := atomic.AddInt64(&b.bytes, int64(n))
-	now := time.Now()
-	if b.cb != nil && now.Sub(b.lastCbAt) > b.cbEvery {
-		b.cb(full)
-		b.lastCbAt = now
-	}
-	return n, err
-}
-
-func (b *ByteCounterReader) Bytes() int64 {
-	return atomic.LoadInt64(&b.bytes)
-}
diff --git a/util/bytecounter/bytecounter_streamcopier.go b/util/bytecounter/bytecounter_streamcopier.go
deleted file mode 100644
index 568270b..0000000
--- a/util/bytecounter/bytecounter_streamcopier.go
+++ /dev/null
@@ -1,71 +0,0 @@
-package bytecounter
-
-import (
-	"io"
-	"sync/atomic"
-
-	"github.com/zrepl/zrepl/zfs"
-)
-
-// StreamCopier wraps a zfs.StreamCopier, reimplementing
-// its interface and counting the bytes written to during copying.
-type StreamCopier interface {
-	zfs.StreamCopier
-	Count() int64
-}
-
-// NewStreamCopier wraps sc into a StreamCopier.
-// If sc is io.Reader, it is guaranteed that the returned StreamCopier
-// implements that interface, too.
-func NewStreamCopier(sc zfs.StreamCopier) StreamCopier {
-	bsc := &streamCopier{sc, 0}
-	if scr, ok := sc.(io.Reader); ok {
-		return streamCopierAndReader{bsc, scr}
-	} else {
-		return bsc
-	}
-}
-
-type streamCopier struct {
-	sc    zfs.StreamCopier
-	count int64
-}
-
-// proxy writer used by streamCopier
-type streamCopierWriter struct {
-	parent *streamCopier
-	w      io.Writer
-}
-
-func (w streamCopierWriter) Write(p []byte) (n int, err error) {
-	n, err = w.w.Write(p)
-	atomic.AddInt64(&w.parent.count, int64(n))
-	return
-}
-
-func (s *streamCopier) Count() int64 {
-	return atomic.LoadInt64(&s.count)
-}
-
-var _ zfs.StreamCopier = &streamCopier{}
-
-func (s streamCopier) Close() error {
-	return s.sc.Close()
-}
-
-func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
-	ww := streamCopierWriter{s, w}
-	return s.sc.WriteStreamTo(ww)
-}
-
-// a streamCopier whose underlying sc is an io.Reader
-type streamCopierAndReader struct {
-	*streamCopier
-	asReader io.Reader
-}
-
-func (scr streamCopierAndReader) Read(p []byte) (int, error) {
-	n, err := scr.asReader.Read(p)
-	atomic.AddInt64(&scr.streamCopier.count, int64(n))
-	return n, err
-}
diff --git a/util/bytecounter/bytecounter_streamcopier_test.go b/util/bytecounter/bytecounter_streamcopier_test.go
deleted file mode 100644
index f4f8b1e..0000000
--- a/util/bytecounter/bytecounter_streamcopier_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-package bytecounter
-
-import (
-	"io"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-
-	"github.com/zrepl/zrepl/zfs"
-)
-
-type mockStreamCopierAndReader struct {
-	zfs.StreamCopier // to satisfy interface
-	reads            int
-}
-
-func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) {
-	r.reads++
-	return len(p), nil
-}
-
-var _ io.Reader = &mockStreamCopierAndReader{}
-
-func TestNewStreamCopierReexportsReader(t *testing.T) {
-	mock := &mockStreamCopierAndReader{}
-	x := NewStreamCopier(mock)
-
-	r, ok := x.(io.Reader)
-	if !ok {
-		t.Fatalf("%T does not implement io.Reader, hence reader cannout have been wrapped", x)
-	}
-
-	var buf [23]byte
-	n, err := r.Read(buf[:])
-	assert.True(t, mock.reads == 1)
-	assert.True(t, n == len(buf))
-	assert.NoError(t, err)
-	assert.True(t, x.Count() == 23)
-}
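Using the replacement counter is a one-liner compared to the StreamCopier plumbing deleted above; a usage sketch against the API added in util/bytecounter/bytecounter_readcloser.go (the payload is made up):

```go
package main

import (
	"fmt"
	"io"
	"strings"

	"github.com/zrepl/zrepl/util/bytecounter"
)

func main() {
	src := io.NopCloser(strings.NewReader("some zfs send stream"))
	counted := bytecounter.NewReadCloser(src)
	defer counted.Close()

	_, _ = io.Copy(io.Discard, counted)
	// Count is backed by an atomic, so (like Step.byteCounter in
	// replication_logic.go) it may be read while the copy is running.
	fmt.Println("bytes seen:", counted.Count())
}
```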
diff --git a/zfs/zfs.go b/zfs/zfs.go
index 2b21ef4..d2297be 100644
--- a/zfs/zfs.go
+++ b/zfs/zfs.go
@@ -354,63 +354,6 @@ func (a ZFSSendArgsUnvalidated) buildCommonSendArgs() ([]string, error) {
 	return args, nil
 }
 
-type ReadCloserCopier struct {
-	recorder readErrRecorder
-}
-
-type readErrRecorder struct {
-	io.ReadCloser
-	readErr error
-}
-
-type sendStreamCopierError struct {
-	isReadErr bool // if false, it's a write error
-	err       error
-}
-
-func (e sendStreamCopierError) Error() string {
-	if e.isReadErr {
-		return fmt.Sprintf("stream: read error: %s", e.err)
-	} else {
-		return fmt.Sprintf("stream: writer error: %s", e.err)
-	}
-}
-
-func (e sendStreamCopierError) IsReadError() bool  { return e.isReadErr }
-func (e sendStreamCopierError) IsWriteError() bool { return !e.isReadErr }
-
-func (r *readErrRecorder) Read(p []byte) (n int, err error) {
-	n, err = r.ReadCloser.Read(p)
-	r.readErr = err
-	return n, err
-}
-
-func NewReadCloserCopier(stream io.ReadCloser) *ReadCloserCopier {
-	return &ReadCloserCopier{recorder: readErrRecorder{stream, nil}}
-}
-
-func (c *ReadCloserCopier) WriteStreamTo(w io.Writer) StreamCopierError {
-	debug("sendStreamCopier.WriteStreamTo: begin")
-	_, err := io.Copy(w, &c.recorder)
-	debug("sendStreamCopier.WriteStreamTo: copy done")
-	if err != nil {
-		if c.recorder.readErr != nil {
-			return sendStreamCopierError{isReadErr: true, err: c.recorder.readErr}
-		} else {
-			return sendStreamCopierError{isReadErr: false, err: err}
-		}
-	}
-	return nil
-}
-
-func (c *ReadCloserCopier) Read(p []byte) (n int, err error) {
-	return c.recorder.Read(p)
-}
-
-func (c *ReadCloserCopier) Close() error {
-	return c.recorder.ReadCloser.Close()
-}
-
 func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
 	if capacity <= 0 {
 		panic(fmt.Sprintf("capacity must be positive %v", capacity))
@@ -423,7 +366,7 @@ func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
 	return stdoutReader, stdoutWriter, nil
 }
 
-type sendStream struct {
+type SendStream struct {
 	cmd  *zfscmd.Cmd
 	kill context.CancelFunc
 
@@ -433,7 +376,7 @@ type sendStream struct {
 	opErr error
 }
 
-func (s *sendStream) Read(p []byte) (n int, err error) {
+func (s *SendStream) Read(p []byte) (n int, err error) {
 	s.closeMtx.Lock()
 	opErr := s.opErr
 	s.closeMtx.Unlock()
@@ -454,12 +397,12 @@ func (s *sendStream) Read(p []byte) (n int, err error) {
 	return n, err
 }
 
-func (s *sendStream) Close() error {
+func (s *SendStream) Close() error {
 	debug("sendStream: close called")
 	return s.killAndWait(nil)
 }
 
-func (s *sendStream) killAndWait(precedingReadErr error) error {
+func (s *SendStream) killAndWait(precedingReadErr error) error {
 	debug("sendStream: killAndWait enter")
 	defer debug("sendStream: killAndWait leave")
 
@@ -830,7 +773,7 @@ var ErrEncryptedSendNotSupported = fmt.Errorf("raw sends which are required for
 // (if from is "" a full ZFS send is done)
 //
 // Returns ErrEncryptedSendNotSupported if encrypted send is requested but not supported by CLI
-func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*ReadCloserCopier, error) {
+func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*SendStream, error) {
 	args := make([]string, 0)
 	args = append(args, "send")
 
@@ -879,14 +822,14 @@ func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*ReadCloserCop
 	// close our writing-end of the pipe so that we don't wait for ourselves when reading from the reading end
 	stdoutWriter.Close()
 
-	stream := &sendStream{
+	stream := &SendStream{
 		cmd:          cmd,
 		kill:         cancel,
 		stdoutReader: stdoutReader,
 		stderrBuf:    stderrBuf,
 	}
 
-	return NewReadCloserCopier(stream), nil
+	return stream, nil
 }
 
 type DrySendType string
@@ -1025,24 +968,6 @@ func ZFSSendDry(ctx context.Context, sendArgs ZFSSendArgsValidated) (_ *DrySendI
 	return &si, nil
 }
 
-type StreamCopierError interface {
-	error
-	IsReadError() bool
-	IsWriteError() bool
-}
-
-type StreamCopier interface {
-	// WriteStreamTo writes the stream represented by this StreamCopier
-	// to the given io.Writer.
-	WriteStreamTo(w io.Writer) StreamCopierError
-	// Close must be called as soon as it is clear that no more data will
-	// be read from the StreamCopier.
-	// If StreamCopier gets its data from a connection, it might hold
-	// a lock on the connection until Close is called. Only closing ensures
-	// that the connection can be used afterwards.
-	Close() error
-}
-
 type RecvOptions struct {
 	// Rollback to the oldest snapshot, destroy it, then perform `recv -F`.
 	// Note that this doesn't change property values, i.e. an existing local property value will be kept.
@@ -1067,7 +992,9 @@ func (e *ErrRecvResumeNotSupported) Error() string {
 	return buf.String()
 }
 
-func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier StreamCopier, opts RecvOptions) (err error) {
+const RecvStderrBufSiz = 1 << 15
+
+func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, stream io.ReadCloser, opts RecvOptions) (err error) {
 
 	if err := v.ValidateInMemory(fs); err != nil {
 		return errors.Wrap(err, "invalid version")
@@ -1134,7 +1061,7 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
 	// cannot receive new filesystem stream: invalid backup stream
 	stdout := bytes.NewBuffer(make([]byte, 0, 1024))
-	stderr := bytes.NewBuffer(make([]byte, 0, 1024))
+	stderr := bytes.NewBuffer(make([]byte, 0, RecvStderrBufSiz))
 
 	stdin, stdinWriter, err := pipeWithCapacityHint(ZFSRecvPipeCapacityHint)
 	if err != nil {
@@ -1162,9 +1089,10 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
 
 	debug("started")
 
-	copierErrChan := make(chan StreamCopierError)
+	copierErrChan := make(chan error)
 	go func() {
-		copierErrChan <- streamCopier.WriteStreamTo(stdinWriter)
+		_, err := io.Copy(stdinWriter, stream)
+		copierErrChan <- err
 		stdinWriter.Close()
 	}()
 	waitErrChan := make(chan error)
@@ -1173,6 +1101,10 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
 		if err = cmd.Wait(); err != nil {
 			if rtErr := tryRecvErrorWithResumeToken(ctx, stderr.String()); rtErr != nil {
 				waitErrChan <- rtErr
+			} else if owErr := tryRecvDestroyOrOverwriteEncryptedErr(stderr.Bytes()); owErr != nil {
+				waitErrChan <- owErr
+			} else if readErr := tryRecvCannotReadFromStreamErr(stderr.Bytes()); readErr != nil {
+				waitErrChan <- readErr
 			} else {
 				waitErrChan <- &ZFSError{
 					Stderr:  stderr.Bytes(),
@@ -1183,22 +1115,23 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
 		}
 	}()
 
-	// streamCopier always fails before or simultaneously with Wait
-	// thus receive from it first
 	copierErr := <-copierErrChan
 	debug("copierErr: %T %s", copierErr, copierErr)
 	if copierErr != nil {
+		debug("killing zfs recv command after copierErr")
 		cancelCmd()
 	}
 
 	waitErr := <-waitErrChan
 	debug("waitErr: %T %s", waitErr, waitErr)
+
 	if copierErr == nil && waitErr == nil {
 		return nil
-	} else if waitErr != nil && (copierErr == nil || copierErr.IsWriteError()) {
-		return waitErr // has more interesting info in that case
+	} else if _, isReadErr := waitErr.(*RecvCannotReadFromStreamErr); isReadErr {
+		return copierErr // likely network error reading from stream
+	} else {
+		return waitErr // almost always more interesting info. NOTE: do not wrap!
 	}
-	return copierErr // if it's not a write error, the copier error is more interesting
 }
 
 type RecvFailedWithResumeTokenErr struct {
@@ -1228,6 +1161,43 @@ func (e *RecvFailedWithResumeTokenErr) Error() string {
 	return fmt.Sprintf("receive failed, resume token available: %s\n%#v", e.ResumeTokenRaw, e.ResumeTokenParsed)
 }
 
+type RecvDestroyOrOverwriteEncryptedErr struct {
+	Msg string
+}
+
+func (e *RecvDestroyOrOverwriteEncryptedErr) Error() string {
+	return e.Msg
+}
+
+var recvDestroyOrOverwriteEncryptedErrRe = regexp.MustCompile(`^(cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one)`)
+
+func tryRecvDestroyOrOverwriteEncryptedErr(stderr []byte) *RecvDestroyOrOverwriteEncryptedErr {
+	debug("tryRecvDestroyOrOverwriteEncryptedErr: %v", stderr)
+	m := recvDestroyOrOverwriteEncryptedErrRe.FindSubmatch(stderr)
+	if m == nil {
+		return nil
+	}
+	return &RecvDestroyOrOverwriteEncryptedErr{Msg: string(m[1])}
+}
+
+type RecvCannotReadFromStreamErr struct {
+	Msg string
+}
+
+func (e *RecvCannotReadFromStreamErr) Error() string {
+	return e.Msg
+}
+
+var reRecvCannotReadFromStreamErr = regexp.MustCompile(`^(cannot receive: failed to read from stream)$`)
+
+func tryRecvCannotReadFromStreamErr(stderr []byte) *RecvCannotReadFromStreamErr {
+	m := reRecvCannotReadFromStreamErr.FindSubmatch(stderr)
+	if m == nil {
+		return nil
+	}
+	return &RecvCannotReadFromStreamErr{Msg: string(m[1])}
+}
+
 type ClearResumeTokenError struct {
 	ZFSOutput []byte
 	CmdError  error
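With the two stderr matchers added above, callers of ZFSRecv can branch on typed errors instead of grepping stderr themselves. A sketch of such matching; classifyRecvErr is a hypothetical helper, not part of the patch:

```go
package main

import (
	"fmt"

	"github.com/zrepl/zrepl/zfs"
)

// classifyRecvErr discriminates the error types ZFSRecv can now return.
func classifyRecvErr(err error) string {
	switch err.(type) {
	case *zfs.RecvFailedWithResumeTokenErr:
		return "resumable: inspect or clear the resume token"
	case *zfs.RecvDestroyOrOverwriteEncryptedErr:
		return "encryption mismatch: recv -F refused"
	case *zfs.RecvCannotReadFromStreamErr:
		return "read-side failure: the copier error is the interesting one"
	default:
		return "unclassified zfs recv failure"
	}
}

func main() {
	err := &zfs.RecvCannotReadFromStreamErr{Msg: "cannot receive: failed to read from stream"}
	fmt.Println(classifyRecvErr(err))
}
```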
diff --git a/zfs/zfs_test.go b/zfs/zfs_test.go
index 606ab54..9871640 100644
--- a/zfs/zfs_test.go
+++ b/zfs/zfs_test.go
@@ -2,9 +2,11 @@ package zfs
 
 import (
 	"context"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestZFSListHandlesProducesZFSErrorOnNonZeroExit(t *testing.T) {
@@ -259,3 +261,12 @@ size	10518512
 		})
 	}
 }
+
+func TestTryRecvDestroyOrOverwriteEncryptedErr(t *testing.T) {
+	msg := "cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one\n"
+	assert.GreaterOrEqual(t, RecvStderrBufSiz, len(msg))
+
+	err := tryRecvDestroyOrOverwriteEncryptedErr([]byte(msg))
+	require.NotNil(t, err)
+	assert.EqualError(t, err, strings.TrimSpace(msg))
+}
diff --git a/zfs/zfscmd/zfscmd_platform_test.go b/zfs/zfscmd/zfscmd_platform_test.go
index 2a34366..5bb5d6c 100644
--- a/zfs/zfscmd/zfscmd_platform_test.go
+++ b/zfs/zfscmd/zfscmd_platform_test.go
@@ -5,11 +5,14 @@ import (
 	"bytes"
 	"context"
 	"io"
+	"os"
 	"os/exec"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"github.com/zrepl/zrepl/util/circlog"
 )
 
 const testBin = "./zfscmd_platform_test.bash"
@@ -85,5 +88,37 @@ func TestCmdProcessState(t *testing.T) {
 	require.True(t, ok)
 	require.NotNil(t, ee.ProcessState)
 	require.Contains(t, ee.Error(), "killed")
+}
+
+func TestSigpipe(t *testing.T) {
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	cmd := CommandContext(ctx, "bash", "-c", "sleep 5; echo invalid input; exit 23")
+	r, w, err := os.Pipe()
+	require.NoError(t, err)
+	output := circlog.MustNewCircularLog(1 << 20)
+	cmd.SetStdio(Stdio{
+		Stdin:  r,
+		Stdout: output,
+		Stderr: output,
+	})
+	err = cmd.Start()
+	require.NoError(t, err)
+	err = r.Close()
+	require.NoError(t, err)
+
+	// the script doesn't read stdin, but this input is almost certainly larger than the pipe buffer
+	const LargerThanPipeBuffer = 1 << 21
+	_, err = io.Copy(w, bytes.NewBuffer(bytes.Repeat([]byte("i"), LargerThanPipeBuffer)))
+	// => io.Copy is going to block because the pipe buffer is full and the
+	//    script is not reading from it
+	// => the script is going to exit after 5s
+	// => we should expect a broken pipe error from the copier's perspective
+	t.Logf("copy err = %T: %s", err, err)
+	require.NotNil(t, err)
+	require.True(t, strings.Contains(err.Error(), "broken pipe"))
+
+	err = cmd.Wait()
+	require.EqualError(t, err, "exit status 23")
 }
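TestSigpipe banks on ordinary EPIPE semantics: the Go runtime only lets SIGPIPE kill the process for writes to stdout or stderr; on any other descriptor the write simply fails with a broken-pipe error, so the blocked io.Copy returns once the child exits. The same effect reproduced without zfscmd, as a stand-alone sketch:

```go
package main

import (
	"fmt"
	"io"
	"os"
)

// neverEnding is a trivial infinite reader.
type neverEnding byte

func (b neverEnding) Read(p []byte) (int, error) {
	for i := range p {
		p[i] = byte(b)
	}
	return len(p), nil
}

func main() {
	r, w, _ := os.Pipe()
	r.Close() // the reader goes away, like the exiting child process

	// With no reader, the write fails with EPIPE instead of killing us.
	_, err := io.Copy(w, io.LimitReader(neverEnding('i'), 1<<21))
	fmt.Println(err) // e.g. "write |1: broken pipe"
}
```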