[#277] rpc + zfs: drop zfs.StreamCopier, use io.ReadCloser instead

This commit is contained in:
Christian Schwarz 2020-04-10 21:58:28 +02:00
parent 0280727985
commit 0e5c77d2be
18 changed files with 243 additions and 414 deletions

View File

@ -4,6 +4,7 @@ package endpoint
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"path" "path"
"sync" "sync"
@ -241,7 +242,7 @@ func sendArgsFromPDUAndValidateExistsAndGetVersion(ctx context.Context, fs strin
return version, nil return version, nil
} }
func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
_, err := s.filterCheckFS(r.Filesystem) _, err := s.filterCheckFS(r.Filesystem)
if err != nil { if err != nil {
@ -339,11 +340,11 @@ func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.St
// step holds & replication cursor released / moved forward in s.SendCompleted => s.moveCursorAndReleaseSendHolds // step holds & replication cursor released / moved forward in s.SendCompleted => s.moveCursorAndReleaseSendHolds
streamCopier, err := zfs.ZFSSend(ctx, sendArgs) sendStream, err := zfs.ZFSSend(ctx, sendArgs)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "zfs send failed") return nil, nil, errors.Wrap(err, "zfs send failed")
} }
return res, streamCopier, nil return res, sendStream, nil
} }
func (p *Sender) SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error) { func (p *Sender) SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error) {
@ -476,7 +477,7 @@ func (p *Sender) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCurs
return &pdu.ReplicationCursorRes{Result: &pdu.ReplicationCursorRes_Guid{Guid: cursor.Guid}}, nil return &pdu.ReplicationCursorRes{Result: &pdu.ReplicationCursorRes_Guid{Guid: cursor.Guid}}, nil
} }
func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) { func (p *Sender) Receive(ctx context.Context, r *pdu.ReceiveReq, _ io.ReadCloser) (*pdu.ReceiveRes, error) {
return nil, fmt.Errorf("sender does not implement Receive()") return nil, fmt.Errorf("sender does not implement Receive()")
} }
@ -680,13 +681,13 @@ func (s *Receiver) ReplicationCursor(context.Context, *pdu.ReplicationCursorReq)
return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver") return nil, fmt.Errorf("ReplicationCursor not implemented for Receiver")
} }
func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { func (s *Receiver) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
return nil, nil, fmt.Errorf("receiver does not implement Send()") return nil, nil, fmt.Errorf("receiver does not implement Send()")
} }
var maxConcurrentZFSRecvSemaphore = semaphore.New(envconst.Int64("ZREPL_ENDPOINT_MAX_CONCURRENT_RECV", 10)) var maxConcurrentZFSRecvSemaphore = semaphore.New(envconst.Int64("ZREPL_ENDPOINT_MAX_CONCURRENT_RECV", 10))
func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) { func (s *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error) {
getLogger(ctx).Debug("incoming Receive") getLogger(ctx).Debug("incoming Receive")
defer receive.Close() defer receive.Close()

View File

@ -136,7 +136,7 @@ func makeResumeSituation(ctx *platformtest.Context, src dummySnapshotSituation,
return situation return situation
} }
limitedCopier := zfs.NewReadCloserCopier(limitio.ReadCloser(copier, src.dummyDataLen/2)) limitedCopier := limitio.ReadCloser(copier, src.dummyDataLen/2)
defer limitedCopier.Close() defer limitedCopier.Close()
require.NotNil(ctx, sendArgs.To) require.NotNil(ctx, sendArgs.To)

View File

@ -35,7 +35,7 @@ func SendArgsValidationEncryptedSendOfUnencryptedDatasetForbidden(ctx *platformt
ResumeToken: "", ResumeToken: "",
}.Validate(ctx) }.Validate(ctx)
var stream *zfs.ReadCloserCopier var stream *zfs.SendStream
if err == nil { if err == nil {
stream, err = zfs.ZFSSend(ctx, sendArgs) // no shadow stream, err = zfs.ZFSSend(ctx, sendArgs) // no shadow
if err == nil { if err == nil {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"sync" "sync"
"time" "time"
@ -38,7 +39,7 @@ type Sender interface {
// If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before // If a non-nil io.ReadCloser is returned, it is guaranteed to be closed before
// any next call to the parent github.com/zrepl/zrepl/replication.Endpoint. // any next call to the parent github.com/zrepl/zrepl/replication.Endpoint.
// If the send request is for dry run the io.ReadCloser will be nil // If the send request is for dry run the io.ReadCloser will be nil
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error) SendCompleted(ctx context.Context, r *pdu.SendCompletedReq) (*pdu.SendCompletedRes, error)
ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error) ReplicationCursor(ctx context.Context, req *pdu.ReplicationCursorReq) (*pdu.ReplicationCursorRes, error)
} }
@ -47,7 +48,7 @@ type Receiver interface {
Endpoint Endpoint
// Receive sends r and sendStream (the latter containing a ZFS send stream) // Receive sends r and sendStream (the latter containing a ZFS send stream)
// to the parent github.com/zrepl/zrepl/replication.Endpoint. // to the parent github.com/zrepl/zrepl/replication.Endpoint.
Receive(ctx context.Context, req *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) Receive(ctx context.Context, req *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
} }
type PlannerPolicy struct { type PlannerPolicy struct {
@ -162,7 +163,7 @@ type Step struct {
// byteCounter is nil initially, and set later in Step.doReplication // byteCounter is nil initially, and set later in Step.doReplication
// => concurrent read of that pointer from Step.ReportInfo must be protected // => concurrent read of that pointer from Step.ReportInfo must be protected
byteCounter bytecounter.StreamCopier byteCounter bytecounter.ReadCloser
byteCounterMtx chainlock.L byteCounterMtx chainlock.L
} }
@ -606,19 +607,19 @@ func (s *Step) doReplication(ctx context.Context) error {
sr := s.buildSendRequest(false) sr := s.buildSendRequest(false)
log.Debug("initiate send request") log.Debug("initiate send request")
sres, sstreamCopier, err := s.sender.Send(ctx, sr) sres, stream, err := s.sender.Send(ctx, sr)
if err != nil { if err != nil {
log.WithError(err).Error("send request failed") log.WithError(err).Error("send request failed")
return err return err
} }
if sstreamCopier == nil { if stream == nil {
err := errors.New("send request did not return a stream, broken endpoint implementation") err := errors.New("send request did not return a stream, broken endpoint implementation")
return err return err
} }
defer sstreamCopier.Close() defer stream.Close()
// Install a byte counter to track progress + for status report // Install a byte counter to track progress + for status report
byteCountingStream := bytecounter.NewStreamCopier(sstreamCopier) byteCountingStream := bytecounter.NewReadCloser(stream)
s.byteCounterMtx.Lock() s.byteCounterMtx.Lock()
s.byteCounter = byteCountingStream s.byteCounter = byteCountingStream
s.byteCounterMtx.Unlock() s.byteCounterMtx.Unlock()

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"strings" "strings"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -11,7 +12,6 @@ import (
"github.com/zrepl/zrepl/replication/logic/pdu" "github.com/zrepl/zrepl/replication/logic/pdu"
"github.com/zrepl/zrepl/rpc/dataconn/stream" "github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/transport" "github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/zfs"
) )
type Client struct { type Client struct {
@ -26,7 +26,7 @@ func NewClient(connecter transport.Connecter, log Logger) *Client {
} }
} }
func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, streamCopier zfs.StreamCopier) error { func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, req proto.Message, stream io.ReadCloser) error {
var buf bytes.Buffer var buf bytes.Buffer
_, memErr := buf.WriteString(endpoint) _, memErr := buf.WriteString(endpoint)
@ -46,8 +46,8 @@ func (c *Client) send(ctx context.Context, conn *stream.Conn, endpoint string, r
return err return err
} }
if streamCopier != nil { if stream != nil {
return conn.SendStream(ctx, streamCopier, ZFSStream) return conn.SendStream(ctx, stream, ZFSStream)
} else { } else {
return nil return nil
} }
@ -109,7 +109,7 @@ func (c *Client) putWire(conn *stream.Conn) {
} }
} }
func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
conn, err := c.getWire(ctx) conn, err := c.getWire(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -130,16 +130,19 @@ func (c *Client) ReqSend(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, z
return nil, nil, err return nil, nil, err
} }
var copier zfs.StreamCopier = nil var stream io.ReadCloser
if !req.DryRun { if !req.DryRun {
putWireOnReturn = false putWireOnReturn = false
copier = &streamCopier{streamConn: conn, closeStreamOnClose: true} stream, err = conn.ReadStream(ZFSStream, true) // no shadow
if err != nil {
return nil, nil, err
}
} }
return &res, copier, nil return &res, stream, nil
} }
func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) { func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
defer c.log.Debug("ReqRecv returns") defer c.log.Debug("ReqRecv returns")
conn, err := c.getWire(ctx) conn, err := c.getWire(ctx)
@ -166,7 +169,7 @@ func (c *Client) ReqRecv(ctx context.Context, req *pdu.ReceiveReq, streamCopier
sendErrChan := make(chan error) sendErrChan := make(chan error)
go func() { go func() {
if err := c.send(ctx, conn, EndpointRecv, req, streamCopier); err != nil { if err := c.send(ctx, conn, EndpointRecv, req, stream); err != nil {
sendErrChan <- err sendErrChan <- err
} else { } else {
sendErrChan <- nil sendErrChan <- nil

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
@ -11,7 +12,6 @@ import (
"github.com/zrepl/zrepl/replication/logic/pdu" "github.com/zrepl/zrepl/replication/logic/pdu"
"github.com/zrepl/zrepl/rpc/dataconn/stream" "github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/transport" "github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/zfs"
) )
// WireInterceptor has a chance to exchange the context and connection on each client connection. // WireInterceptor has a chance to exchange the context and connection on each client connection.
@ -21,11 +21,11 @@ type WireInterceptor func(ctx context.Context, rawConn *transport.AuthConn) (con
type Handler interface { type Handler interface {
// Send handles a SendRequest. // Send handles a SendRequest.
// The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run. // The returned io.ReadCloser is allowed to be nil, for example if the requested Send is a dry-run.
Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error)
// Receive handles a ReceiveRequest. // Receive handles a ReceiveRequest.
// It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout // It is guaranteed that Server calls Receive with a stream that holds the IdleConnTimeout
// configured in ServerConfig.Shared.IdleConnTimeout. // configured in ServerConfig.Shared.IdleConnTimeout.
Receive(ctx context.Context, r *pdu.ReceiveReq, receive zfs.StreamCopier) (*pdu.ReceiveRes, error) Receive(ctx context.Context, r *pdu.ReceiveReq, receive io.ReadCloser) (*pdu.ReceiveRes, error)
// PingDataconn handles a PingReq // PingDataconn handles a PingReq
PingDataconn(ctx context.Context, r *pdu.PingReq) (*pdu.PingRes, error) PingDataconn(ctx context.Context, r *pdu.PingReq) (*pdu.PingRes, error)
} }
@ -111,7 +111,7 @@ func (s *Server) serveConn(nc *transport.AuthConn) {
s.log.WithField("endpoint", endpoint).Debug("calling handler") s.log.WithField("endpoint", endpoint).Debug("calling handler")
var res proto.Message var res proto.Message
var sendStream zfs.StreamCopier var sendStream io.ReadCloser
var handlerErr error var handlerErr error
switch endpoint { switch endpoint {
case EndpointSend: case EndpointSend:
@ -127,7 +127,12 @@ func (s *Server) serveConn(nc *transport.AuthConn) {
s.log.WithError(err).Error("cannot unmarshal receive request") s.log.WithError(err).Error("cannot unmarshal receive request")
return return
} }
res, handlerErr = s.h.Receive(ctx, &req, &streamCopier{streamConn: c, closeStreamOnClose: false}) // SHADOWING stream, err := c.ReadStream(ZFSStream, false)
if err != nil {
s.log.WithError(err).Error("cannot open stream in receive request")
return
}
res, handlerErr = s.h.Receive(ctx, &req, stream) // SHADOWING
case EndpointPing: case EndpointPing:
var req pdu.PingReq var req pdu.PingReq
if err := proto.Unmarshal(reqStructured, &req); err != nil { if err := proto.Unmarshal(reqStructured, &req); err != nil {

View File

@ -1,12 +1,7 @@
package dataconn package dataconn
import ( import (
"io"
"sync"
"time" "time"
"github.com/zrepl/zrepl/rpc/dataconn/stream"
"github.com/zrepl/zrepl/zfs"
) )
const ( const (
@ -39,33 +34,3 @@ const (
responseHeaderHandlerOk = "HANDLER OK\n" responseHeaderHandlerOk = "HANDLER OK\n"
responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n" responseHeaderHandlerErrorPrefix = "HANDLER ERROR:\n"
) )
type streamCopier struct {
mtx sync.Mutex
used bool
streamConn *stream.Conn
closeStreamOnClose bool
}
// WriteStreamTo implements zfs.StreamCopier
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
s.mtx.Lock()
defer s.mtx.Unlock()
if s.used {
panic("streamCopier used multiple times")
}
s.used = true
return s.streamConn.ReadStreamInto(w, ZFSStream)
}
// Close implements zfs.StreamCopier
func (s *streamCopier) Close() error {
// only record the close here, what we do actually depends on whether
// the streamCopier is instantiated server-side or client-side
s.mtx.Lock()
defer s.mtx.Unlock()
if s.closeStreamOnClose {
return s.streamConn.Close()
}
return nil
}

View File

@ -29,7 +29,6 @@ import (
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
"github.com/zrepl/zrepl/transport" "github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/util/devnoop" "github.com/zrepl/zrepl/util/devnoop"
"github.com/zrepl/zrepl/zfs"
) )
func orDie(err error) { func orDie(err error) {
@ -42,23 +41,9 @@ type readerStreamCopier struct{ io.Reader }
func (readerStreamCopier) Close() error { return nil } func (readerStreamCopier) Close() error { return nil }
type readerStreamCopierErr struct {
error
}
func (readerStreamCopierErr) IsReadError() bool { return false }
func (readerStreamCopierErr) IsWriteError() bool { return true }
func (c readerStreamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
var buf [1 << 21]byte
_, err := io.CopyBuffer(w, c.Reader, buf[:])
// always assume write error
return readerStreamCopierErr{err}
}
type devNullHandler struct{} type devNullHandler struct{}
func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
var res pdu.SendRes var res pdu.SendRes
if args.devnoopReader { if args.devnoopReader {
return &res, readerStreamCopier{devnoop.Get()}, nil return &res, readerStreamCopier{devnoop.Get()}, nil
@ -67,12 +52,12 @@ func (devNullHandler) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, z
} }
} }
func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream zfs.StreamCopier) (*pdu.ReceiveRes, error) { func (devNullHandler) Receive(ctx context.Context, r *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
var out io.Writer = os.Stdout var out io.Writer = os.Stdout
if args.devnoopWriter { if args.devnoopWriter {
out = devnoop.Get() out = devnoop.Get()
} }
err := stream.WriteStreamTo(out) _, err := io.Copy(out, stream)
var res pdu.ReceiveRes var res pdu.ReceiveRes
return &res, err return &res, err
} }
@ -172,7 +157,7 @@ func client() {
req := pdu.SendReq{} req := pdu.SendReq{}
_, stream, err := client.ReqSend(ctx, &req) _, stream, err := client.ReqSend(ctx, &req)
orDie(err) orDie(err)
err = stream.WriteStreamTo(os.Stdout) _, err = io.Copy(os.Stdout, stream)
orDie(err) orDie(err)
case "recv": case "recv":
var r io.Reader = os.Stdin var r io.Reader = os.Stdin

View File

@ -14,7 +14,6 @@ import (
"github.com/zrepl/zrepl/rpc/dataconn/base2bufpool" "github.com/zrepl/zrepl/rpc/dataconn/base2bufpool"
"github.com/zrepl/zrepl/rpc/dataconn/frameconn" "github.com/zrepl/zrepl/rpc/dataconn/frameconn"
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
"github.com/zrepl/zrepl/zfs"
) )
type Logger = logger.Logger type Logger = logger.Logger
@ -198,7 +197,7 @@ func (e ReadStreamError) Temporary() bool {
return false return false
} }
var _ zfs.StreamCopierError = &ReadStreamError{} var _ net.Error = &ReadStreamError{}
func (e ReadStreamError) IsReadError() bool { func (e ReadStreamError) IsReadError() bool {
return e.Kind != ReadStreamErrorKindWrite return e.Kind != ReadStreamErrorKindWrite

View File

@ -14,7 +14,6 @@ import (
"github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn" "github.com/zrepl/zrepl/rpc/dataconn/heartbeatconn"
"github.com/zrepl/zrepl/rpc/dataconn/timeoutconn" "github.com/zrepl/zrepl/rpc/dataconn/timeoutconn"
"github.com/zrepl/zrepl/zfs"
) )
type Conn struct { type Conn struct {
@ -40,15 +39,7 @@ type Conn struct {
var readMessageSentinel = fmt.Errorf("read stream complete") var readMessageSentinel = fmt.Errorf("read stream complete")
type writeStreamToErrorUnknownState struct{} var errWriteStreamToErrorUnknownState = fmt.Errorf("dataconn read stream: connection is in unknown state")
func (e writeStreamToErrorUnknownState) Error() string {
return "dataconn read stream: connection is in unknown state"
}
func (e writeStreamToErrorUnknownState) IsReadError() bool { return true }
func (e writeStreamToErrorUnknownState) IsWriteError() bool { return false }
func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn { func Wrap(nc timeoutconn.Wire, sendHeartbeatInterval, peerTimeout time.Duration) *Conn {
hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout) hc := heartbeatconn.Wrap(nc, sendHeartbeatInterval, peerTimeout)
@ -123,14 +114,28 @@ func (c *Conn) ReadStreamedMessage(ctx context.Context, maxSize uint32, frameTyp
} }
} }
type StreamReader struct {
*io.PipeReader
conn *Conn
closeConnOnClose bool
}
func (r *StreamReader) Close() error {
err := r.PipeReader.Close()
if r.closeConnOnClose {
r.conn.Close() // TODO error logging
}
return err
}
// WriteStreamTo reads a stream from Conn and writes it to w. // WriteStreamTo reads a stream from Conn and writes it to w.
func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopierError) { func (c *Conn) ReadStream(frameType uint32, closeConnOnClose bool) (_ *StreamReader, err error) {
// if we are closed while writing, return that as an error // if we are closed while writing, return that as an error
if closeGuard, cse := c.closeState.RWEntry(); cse != nil { if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
return cse return nil, cse
} else { } else {
defer func(err *zfs.StreamCopierError) { defer func(err *error) {
if closed := closeGuard.RWExit(); closed != nil { if closed := closeGuard.RWExit(); closed != nil {
*err = closed *err = closed
} }
@ -138,18 +143,23 @@ func (c *Conn) ReadStreamInto(w io.Writer, frameType uint32) (err zfs.StreamCopi
} }
c.readMtx.Lock() c.readMtx.Lock()
defer c.readMtx.Unlock()
if !c.readClean { if !c.readClean {
return writeStreamToErrorUnknownState{} return nil, errWriteStreamToErrorUnknownState
} }
var rse *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
c.readClean = isConnCleanAfterRead(rse)
// https://golang.org/doc/faq#nil_error r, w := io.Pipe()
if rse == nil { go func() {
return nil defer c.readMtx.Unlock()
var err *ReadStreamError = readStream(c.frameReads, c.hc, w, frameType)
if err != nil {
_ = w.CloseWithError(err) // doc guarantees that error will always be nil
} else {
w.Close()
} }
return rse c.readClean = isConnCleanAfterRead(err)
}()
return &StreamReader{PipeReader: r, conn: c, closeConnOnClose: closeConnOnClose}, nil
} }
func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) (err error) { func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameType uint32) (err error) {
@ -178,7 +188,7 @@ func (c *Conn) WriteStreamedMessage(ctx context.Context, buf io.Reader, frameTyp
return errConn return errConn
} }
func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType uint32) (err error) { func (c *Conn) SendStream(ctx context.Context, stream io.ReadCloser, frameType uint32) (err error) {
// if we are closed while reading, return that as an error // if we are closed while reading, return that as an error
if closeGuard, cse := c.closeState.RWEntry(); cse != nil { if closeGuard, cse := c.closeState.RWEntry(); cse != nil {
@ -197,49 +207,17 @@ func (c *Conn) SendStream(ctx context.Context, src zfs.StreamCopier, frameType u
return fmt.Errorf("dataconn send stream: connection is in unknown state") return fmt.Errorf("dataconn send stream: connection is in unknown state")
} }
// avoid io.Pipe if zfs.StreamCopier is an io.Reader errStream, errConn := writeStream(ctx, c.hc, stream, frameType)
var r io.Reader
var w *io.PipeWriter
streamCopierErrChan := make(chan zfs.StreamCopierError, 1)
if reader, ok := src.(io.Reader); ok {
r = reader
streamCopierErrChan <- nil
close(streamCopierErrChan)
} else {
r, w = io.Pipe()
go func() {
streamCopierErrChan <- src.WriteStreamTo(w)
w.Close()
}()
}
type writeStreamRes struct { c.writeClean = isConnCleanAfterWrite(errConn) // TODO correct?
errStream, errConn error
}
writeStreamErrChan := make(chan writeStreamRes, 1)
go func() {
var res writeStreamRes
res.errStream, res.errConn = writeStream(ctx, c.hc, r, frameType)
if w != nil {
_ = w.CloseWithError(res.errStream) // always returns nil
}
writeStreamErrChan <- res
}()
writeRes := <-writeStreamErrChan if errStream != nil {
streamCopierErr := <-streamCopierErrChan return errStream
c.writeClean = isConnCleanAfterWrite(writeRes.errConn) // TODO correct? } else if errConn != nil {
if streamCopierErr != nil && streamCopierErr.IsReadError() { return errConn
return streamCopierErr // something on our side is bad
} else {
if writeRes.errStream != nil {
return writeRes.errStream
} else if writeRes.errConn != nil {
return writeRes.errConn
} }
// TODO combined error? // TODO combined error?
return streamCopierErr return nil
}
} }
type closeState struct { type closeState struct {
@ -248,15 +226,11 @@ type closeState struct {
type closeStateErrConnectionClosed struct{} type closeStateErrConnectionClosed struct{}
var _ zfs.StreamCopierError = (*closeStateErrConnectionClosed)(nil)
var _ error = (*closeStateErrConnectionClosed)(nil)
var _ net.Error = (*closeStateErrConnectionClosed)(nil) var _ net.Error = (*closeStateErrConnectionClosed)(nil)
func (e *closeStateErrConnectionClosed) Error() string { func (e *closeStateErrConnectionClosed) Error() string {
return "connection closed" return "connection closed"
} }
func (e *closeStateErrConnectionClosed) IsReadError() bool { return true }
func (e *closeStateErrConnectionClosed) IsWriteError() bool { return true }
func (e *closeStateErrConnectionClosed) Timeout() bool { return false } func (e *closeStateErrConnectionClosed) Timeout() bool { return false }
func (e *closeStateErrConnectionClosed) Temporary() bool { return false } func (e *closeStateErrConnectionClosed) Temporary() bool { return false }
@ -273,7 +247,7 @@ type closeStateEntry struct {
entryCount uint32 entryCount uint32
} }
func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) { func (s *closeState) RWEntry() (e *closeStateEntry, err net.Error) {
entry := &closeStateEntry{s, atomic.LoadUint32(&s.closeCount)} entry := &closeStateEntry{s, atomic.LoadUint32(&s.closeCount)}
if entry.entryCount > 0 { if entry.entryCount > 0 {
return nil, &closeStateErrConnectionClosed{} return nil, &closeStateErrConnectionClosed{}
@ -281,7 +255,7 @@ func (s *closeState) RWEntry() (e *closeStateEntry, err zfs.StreamCopierError) {
return entry, nil return entry, nil
} }
func (e *closeStateEntry) RWExit() zfs.StreamCopierError { func (e *closeStateEntry) RWExit() net.Error {
if atomic.LoadUint32(&e.entryCount) == e.entryCount { if atomic.LoadUint32(&e.entryCount) == e.entryCount {
// no calls to Close() while running rw operation // no calls to Close() while running rw operation
return nil return nil

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -20,7 +21,6 @@ import (
"github.com/zrepl/zrepl/rpc/versionhandshake" "github.com/zrepl/zrepl/rpc/versionhandshake"
"github.com/zrepl/zrepl/transport" "github.com/zrepl/zrepl/transport"
"github.com/zrepl/zrepl/util/envconst" "github.com/zrepl/zrepl/util/envconst"
"github.com/zrepl/zrepl/zfs"
) )
// Client implements the active side of a replication setup. // Client implements the active side of a replication setup.
@ -82,22 +82,22 @@ func (c *Client) Close() {
// callers must ensure that the returned io.ReadCloser is closed // callers must ensure that the returned io.ReadCloser is closed
// TODO expose dataClient interface to the outside world // TODO expose dataClient interface to the outside world
func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.StreamCopier, error) { func (c *Client) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) {
// TODO the returned sendStream may return a read error created by the remote side // TODO the returned sendStream may return a read error created by the remote side
res, streamCopier, err := c.dataClient.ReqSend(ctx, r) res, stream, err := c.dataClient.ReqSend(ctx, r)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if streamCopier == nil { if stream == nil {
return res, nil, nil return res, nil, nil
} }
return res, streamCopier, nil return res, stream, nil
} }
func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, streamCopier zfs.StreamCopier) (*pdu.ReceiveRes, error) { func (c *Client) Receive(ctx context.Context, req *pdu.ReceiveReq, stream io.ReadCloser) (*pdu.ReceiveRes, error) {
return c.dataClient.ReqRecv(ctx, req, streamCopier) return c.dataClient.ReqRecv(ctx, req, stream)
} }
func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) { func (c *Client) ListFilesystems(ctx context.Context, in *pdu.ListFilesystemReq) (*pdu.ListFilesystemRes, error) {

View File

@ -0,0 +1,39 @@
package bytecounter
import (
"io"
"sync/atomic"
)
// ReadCloser wraps an io.ReadCloser, reimplementing
// its interface and counting the bytes written to during copying.
type ReadCloser interface {
io.ReadCloser
Count() int64
}
// NewReadCloser wraps rc.
func NewReadCloser(rc io.ReadCloser) ReadCloser {
return &readCloser{rc, 0}
}
type readCloser struct {
rc io.ReadCloser
count int64
}
func (r *readCloser) Count() int64 {
return atomic.LoadInt64(&r.count)
}
var _ io.ReadCloser = &readCloser{}
func (r *readCloser) Close() error {
return r.rc.Close()
}
func (r *readCloser) Read(p []byte) (int, error) {
n, err := r.rc.Read(p)
atomic.AddInt64(&r.count, int64(n))
return n, err
}

View File

@ -1,49 +0,0 @@
package bytecounter
import (
"io"
"sync/atomic"
"time"
)
type ByteCounterReader struct {
reader io.ReadCloser
// called & accessed synchronously during Read, no external access
cb func(full int64)
cbEvery time.Duration
lastCbAt time.Time
// set atomically because it may be read by multiple threads
bytes int64
}
func NewByteCounterReader(reader io.ReadCloser) *ByteCounterReader {
return &ByteCounterReader{
reader: reader,
}
}
func (b *ByteCounterReader) SetCallback(every time.Duration, cb func(full int64)) {
b.cbEvery = every
b.cb = cb
}
func (b *ByteCounterReader) Close() error {
return b.reader.Close()
}
func (b *ByteCounterReader) Read(p []byte) (n int, err error) {
n, err = b.reader.Read(p)
full := atomic.AddInt64(&b.bytes, int64(n))
now := time.Now()
if b.cb != nil && now.Sub(b.lastCbAt) > b.cbEvery {
b.cb(full)
b.lastCbAt = now
}
return n, err
}
func (b *ByteCounterReader) Bytes() int64 {
return atomic.LoadInt64(&b.bytes)
}

View File

@ -1,71 +0,0 @@
package bytecounter
import (
"io"
"sync/atomic"
"github.com/zrepl/zrepl/zfs"
)
// StreamCopier wraps a zfs.StreamCopier, reimplementing
// its interface and counting the bytes written to during copying.
type StreamCopier interface {
zfs.StreamCopier
Count() int64
}
// NewStreamCopier wraps sc into a StreamCopier.
// If sc is io.Reader, it is guaranteed that the returned StreamCopier
// implements that interface, too.
func NewStreamCopier(sc zfs.StreamCopier) StreamCopier {
bsc := &streamCopier{sc, 0}
if scr, ok := sc.(io.Reader); ok {
return streamCopierAndReader{bsc, scr}
} else {
return bsc
}
}
type streamCopier struct {
sc zfs.StreamCopier
count int64
}
// proxy writer used by streamCopier
type streamCopierWriter struct {
parent *streamCopier
w io.Writer
}
func (w streamCopierWriter) Write(p []byte) (n int, err error) {
n, err = w.w.Write(p)
atomic.AddInt64(&w.parent.count, int64(n))
return
}
func (s *streamCopier) Count() int64 {
return atomic.LoadInt64(&s.count)
}
var _ zfs.StreamCopier = &streamCopier{}
func (s streamCopier) Close() error {
return s.sc.Close()
}
func (s *streamCopier) WriteStreamTo(w io.Writer) zfs.StreamCopierError {
ww := streamCopierWriter{s, w}
return s.sc.WriteStreamTo(ww)
}
// a streamCopier whose underlying sc is an io.Reader
type streamCopierAndReader struct {
*streamCopier
asReader io.Reader
}
func (scr streamCopierAndReader) Read(p []byte) (int, error) {
n, err := scr.asReader.Read(p)
atomic.AddInt64(&scr.streamCopier.count, int64(n))
return n, err
}

View File

@ -1,39 +0,0 @@
package bytecounter
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zrepl/zrepl/zfs"
)
type mockStreamCopierAndReader struct {
zfs.StreamCopier // to satisfy interface
reads int
}
func (r *mockStreamCopierAndReader) Read(p []byte) (int, error) {
r.reads++
return len(p), nil
}
var _ io.Reader = &mockStreamCopierAndReader{}
func TestNewStreamCopierReexportsReader(t *testing.T) {
mock := &mockStreamCopierAndReader{}
x := NewStreamCopier(mock)
r, ok := x.(io.Reader)
if !ok {
t.Fatalf("%T does not implement io.Reader, hence reader cannout have been wrapped", x)
}
var buf [23]byte
n, err := r.Read(buf[:])
assert.True(t, mock.reads == 1)
assert.True(t, n == len(buf))
assert.NoError(t, err)
assert.True(t, x.Count() == 23)
}

View File

@ -354,63 +354,6 @@ func (a ZFSSendArgsUnvalidated) buildCommonSendArgs() ([]string, error) {
return args, nil return args, nil
} }
type ReadCloserCopier struct {
recorder readErrRecorder
}
type readErrRecorder struct {
io.ReadCloser
readErr error
}
type sendStreamCopierError struct {
isReadErr bool // if false, it's a write error
err error
}
func (e sendStreamCopierError) Error() string {
if e.isReadErr {
return fmt.Sprintf("stream: read error: %s", e.err)
} else {
return fmt.Sprintf("stream: writer error: %s", e.err)
}
}
func (e sendStreamCopierError) IsReadError() bool { return e.isReadErr }
func (e sendStreamCopierError) IsWriteError() bool { return !e.isReadErr }
func (r *readErrRecorder) Read(p []byte) (n int, err error) {
n, err = r.ReadCloser.Read(p)
r.readErr = err
return n, err
}
func NewReadCloserCopier(stream io.ReadCloser) *ReadCloserCopier {
return &ReadCloserCopier{recorder: readErrRecorder{stream, nil}}
}
func (c *ReadCloserCopier) WriteStreamTo(w io.Writer) StreamCopierError {
debug("sendStreamCopier.WriteStreamTo: begin")
_, err := io.Copy(w, &c.recorder)
debug("sendStreamCopier.WriteStreamTo: copy done")
if err != nil {
if c.recorder.readErr != nil {
return sendStreamCopierError{isReadErr: true, err: c.recorder.readErr}
} else {
return sendStreamCopierError{isReadErr: false, err: err}
}
}
return nil
}
func (c *ReadCloserCopier) Read(p []byte) (n int, err error) {
return c.recorder.Read(p)
}
func (c *ReadCloserCopier) Close() error {
return c.recorder.ReadCloser.Close()
}
func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) { func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
if capacity <= 0 { if capacity <= 0 {
panic(fmt.Sprintf("capacity must be positive %v", capacity)) panic(fmt.Sprintf("capacity must be positive %v", capacity))
@ -423,7 +366,7 @@ func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {
return stdoutReader, stdoutWriter, nil return stdoutReader, stdoutWriter, nil
} }
type sendStream struct { type SendStream struct {
cmd *zfscmd.Cmd cmd *zfscmd.Cmd
kill context.CancelFunc kill context.CancelFunc
@ -433,7 +376,7 @@ type sendStream struct {
opErr error opErr error
} }
func (s *sendStream) Read(p []byte) (n int, err error) { func (s *SendStream) Read(p []byte) (n int, err error) {
s.closeMtx.Lock() s.closeMtx.Lock()
opErr := s.opErr opErr := s.opErr
s.closeMtx.Unlock() s.closeMtx.Unlock()
@ -454,12 +397,12 @@ func (s *sendStream) Read(p []byte) (n int, err error) {
return n, err return n, err
} }
func (s *sendStream) Close() error { func (s *SendStream) Close() error {
debug("sendStream: close called") debug("sendStream: close called")
return s.killAndWait(nil) return s.killAndWait(nil)
} }
func (s *sendStream) killAndWait(precedingReadErr error) error { func (s *SendStream) killAndWait(precedingReadErr error) error {
debug("sendStream: killAndWait enter") debug("sendStream: killAndWait enter")
defer debug("sendStream: killAndWait leave") defer debug("sendStream: killAndWait leave")
@ -830,7 +773,7 @@ var ErrEncryptedSendNotSupported = fmt.Errorf("raw sends which are required for
// (if from is "" a full ZFS send is done) // (if from is "" a full ZFS send is done)
// //
// Returns ErrEncryptedSendNotSupported if encrypted send is requested but not supported by CLI // Returns ErrEncryptedSendNotSupported if encrypted send is requested but not supported by CLI
func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*ReadCloserCopier, error) { func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*SendStream, error) {
args := make([]string, 0) args := make([]string, 0)
args = append(args, "send") args = append(args, "send")
@ -879,14 +822,14 @@ func ZFSSend(ctx context.Context, sendArgs ZFSSendArgsValidated) (*ReadCloserCop
// close our writing-end of the pipe so that we don't wait for ourselves when reading from the reading end // close our writing-end of the pipe so that we don't wait for ourselves when reading from the reading end
stdoutWriter.Close() stdoutWriter.Close()
stream := &sendStream{ stream := &SendStream{
cmd: cmd, cmd: cmd,
kill: cancel, kill: cancel,
stdoutReader: stdoutReader, stdoutReader: stdoutReader,
stderrBuf: stderrBuf, stderrBuf: stderrBuf,
} }
return NewReadCloserCopier(stream), nil return stream, nil
} }
type DrySendType string type DrySendType string
@ -1025,24 +968,6 @@ func ZFSSendDry(ctx context.Context, sendArgs ZFSSendArgsValidated) (_ *DrySendI
return &si, nil return &si, nil
} }
type StreamCopierError interface {
error
IsReadError() bool
IsWriteError() bool
}
type StreamCopier interface {
// WriteStreamTo writes the stream represented by this StreamCopier
// to the given io.Writer.
WriteStreamTo(w io.Writer) StreamCopierError
// Close must be called as soon as it is clear that no more data will
// be read from the StreamCopier.
// If StreamCopier gets its data from a connection, it might hold
// a lock on the connection until Close is called. Only closing ensures
// that the connection can be used afterwards.
Close() error
}
type RecvOptions struct { type RecvOptions struct {
// Rollback to the oldest snapshot, destroy it, then perform `recv -F`. // Rollback to the oldest snapshot, destroy it, then perform `recv -F`.
// Note that this doesn't change property values, i.e. an existing local property value will be kept. // Note that this doesn't change property values, i.e. an existing local property value will be kept.
@ -1067,7 +992,9 @@ func (e *ErrRecvResumeNotSupported) Error() string {
return buf.String() return buf.String()
} }
func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier StreamCopier, opts RecvOptions) (err error) { const RecvStderrBufSiz = 1 << 15
func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, stream io.ReadCloser, opts RecvOptions) (err error) {
if err := v.ValidateInMemory(fs); err != nil { if err := v.ValidateInMemory(fs); err != nil {
return errors.Wrap(err, "invalid version") return errors.Wrap(err, "invalid version")
@ -1134,7 +1061,7 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
// cannot receive new filesystem stream: invalid backup stream // cannot receive new filesystem stream: invalid backup stream
stdout := bytes.NewBuffer(make([]byte, 0, 1024)) stdout := bytes.NewBuffer(make([]byte, 0, 1024))
stderr := bytes.NewBuffer(make([]byte, 0, 1024)) stderr := bytes.NewBuffer(make([]byte, 0, RecvStderrBufSiz))
stdin, stdinWriter, err := pipeWithCapacityHint(ZFSRecvPipeCapacityHint) stdin, stdinWriter, err := pipeWithCapacityHint(ZFSRecvPipeCapacityHint)
if err != nil { if err != nil {
@ -1162,9 +1089,10 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
debug("started") debug("started")
copierErrChan := make(chan StreamCopierError) copierErrChan := make(chan error)
go func() { go func() {
copierErrChan <- streamCopier.WriteStreamTo(stdinWriter) _, err := io.Copy(stdinWriter, stream)
copierErrChan <- err
stdinWriter.Close() stdinWriter.Close()
}() }()
waitErrChan := make(chan error) waitErrChan := make(chan error)
@ -1173,6 +1101,10 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
if err = cmd.Wait(); err != nil { if err = cmd.Wait(); err != nil {
if rtErr := tryRecvErrorWithResumeToken(ctx, stderr.String()); rtErr != nil { if rtErr := tryRecvErrorWithResumeToken(ctx, stderr.String()); rtErr != nil {
waitErrChan <- rtErr waitErrChan <- rtErr
} else if owErr := tryRecvDestroyOrOverwriteEncryptedErr(stderr.Bytes()); owErr != nil {
waitErrChan <- owErr
} else if readErr := tryRecvCannotReadFromStreamErr(stderr.Bytes()); readErr != nil {
waitErrChan <- readErr
} else { } else {
waitErrChan <- &ZFSError{ waitErrChan <- &ZFSError{
Stderr: stderr.Bytes(), Stderr: stderr.Bytes(),
@ -1183,22 +1115,23 @@ func ZFSRecv(ctx context.Context, fs string, v *ZFSSendArgVersion, streamCopier
} }
}() }()
// streamCopier always fails before or simultaneously with Wait
// thus receive from it first
copierErr := <-copierErrChan copierErr := <-copierErrChan
debug("copierErr: %T %s", copierErr, copierErr) debug("copierErr: %T %s", copierErr, copierErr)
if copierErr != nil { if copierErr != nil {
debug("killing zfs recv command after copierErr")
cancelCmd() cancelCmd()
} }
waitErr := <-waitErrChan waitErr := <-waitErrChan
debug("waitErr: %T %s", waitErr, waitErr) debug("waitErr: %T %s", waitErr, waitErr)
if copierErr == nil && waitErr == nil { if copierErr == nil && waitErr == nil {
return nil return nil
} else if waitErr != nil && (copierErr == nil || copierErr.IsWriteError()) { } else if _, isReadErr := waitErr.(*RecvCannotReadFromStreamErr); isReadErr {
return waitErr // has more interesting info in that case return copierErr // likely network error reading from stream
} else {
return waitErr // almost always more interesting info. NOTE: do not wrap!
} }
return copierErr // if it's not a write error, the copier error is more interesting
} }
type RecvFailedWithResumeTokenErr struct { type RecvFailedWithResumeTokenErr struct {
@ -1228,6 +1161,43 @@ func (e *RecvFailedWithResumeTokenErr) Error() string {
return fmt.Sprintf("receive failed, resume token available: %s\n%#v", e.ResumeTokenRaw, e.ResumeTokenParsed) return fmt.Sprintf("receive failed, resume token available: %s\n%#v", e.ResumeTokenRaw, e.ResumeTokenParsed)
} }
type RecvDestroyOrOverwriteEncryptedErr struct {
Msg string
}
func (e *RecvDestroyOrOverwriteEncryptedErr) Error() string {
return e.Msg
}
var recvDestroyOrOverwriteEncryptedErrRe = regexp.MustCompile(`^(cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one)`)
func tryRecvDestroyOrOverwriteEncryptedErr(stderr []byte) *RecvDestroyOrOverwriteEncryptedErr {
debug("tryRecvDestroyOrOverwriteEncryptedErr: %v", stderr)
m := recvDestroyOrOverwriteEncryptedErrRe.FindSubmatch(stderr)
if m == nil {
return nil
}
return &RecvDestroyOrOverwriteEncryptedErr{Msg: string(m[1])}
}
type RecvCannotReadFromStreamErr struct {
Msg string
}
func (e *RecvCannotReadFromStreamErr) Error() string {
return e.Msg
}
var reRecvCannotReadFromStreamErr = regexp.MustCompile(`^(cannot receive: failed to read from stream)$`)
func tryRecvCannotReadFromStreamErr(stderr []byte) *RecvCannotReadFromStreamErr {
m := reRecvCannotReadFromStreamErr.FindSubmatch(stderr)
if m == nil {
return nil
}
return &RecvCannotReadFromStreamErr{Msg: string(m[1])}
}
type ClearResumeTokenError struct { type ClearResumeTokenError struct {
ZFSOutput []byte ZFSOutput []byte
CmdError error CmdError error

View File

@ -2,9 +2,11 @@ package zfs
import ( import (
"context" "context"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestZFSListHandlesProducesZFSErrorOnNonZeroExit(t *testing.T) { func TestZFSListHandlesProducesZFSErrorOnNonZeroExit(t *testing.T) {
@ -259,3 +261,12 @@ size 10518512
}) })
} }
} }
func TestTryRecvDestroyOrOverwriteEncryptedErr(t *testing.T) {
msg := "cannot receive new filesystem stream: zfs receive -F cannot be used to destroy an encrypted filesystem or overwrite an unencrypted one with an encrypted one\n"
assert.GreaterOrEqual(t, RecvStderrBufSiz, len(msg))
err := tryRecvDestroyOrOverwriteEncryptedErr([]byte(msg))
require.NotNil(t, err)
assert.EqualError(t, err, strings.TrimSpace(msg))
}

View File

@ -5,11 +5,14 @@ import (
"bytes" "bytes"
"context" "context"
"io" "io"
"os"
"os/exec" "os/exec"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zrepl/zrepl/util/circlog"
) )
const testBin = "./zfscmd_platform_test.bash" const testBin = "./zfscmd_platform_test.bash"
@ -85,5 +88,37 @@ func TestCmdProcessState(t *testing.T) {
require.True(t, ok) require.True(t, ok)
require.NotNil(t, ee.ProcessState) require.NotNil(t, ee.ProcessState)
require.Contains(t, ee.Error(), "killed") require.Contains(t, ee.Error(), "killed")
}
func TestSigpipe(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cmd := CommandContext(ctx, "bash", "-c", "sleep 5; echo invalid input; exit 23")
r, w, err := os.Pipe()
require.NoError(t, err)
output := circlog.MustNewCircularLog(1 << 20)
cmd.SetStdio(Stdio{
Stdin: r,
Stdout: output,
Stderr: output,
})
err = cmd.Start()
require.NoError(t, err)
err = r.Close()
require.NoError(t, err)
// the script doesn't read stdin, but this input is almost certainly smaller than the pipe buffer
const LargerThanPipeBuffer = 1 << 21
_, err = io.Copy(w, bytes.NewBuffer(bytes.Repeat([]byte("i"), LargerThanPipeBuffer)))
// => io.Copy is going to block because the pipe buffer is full and the
// script is not reading from it
// => the script is going to exit after 5s
// => we should expect a broken pipe error from the copier's perspective
t.Logf("copy err = %T: %s", err, err)
require.NotNil(t, err)
require.True(t, strings.Contains(err.Error(), "broken pipe"))
err = cmd.Wait()
require.EqualError(t, err, "exit status 23")
} }