diff --git a/sshbytestream/ssh.go b/sshbytestream/ssh.go index 23c745f..9716c0a 100644 --- a/sshbytestream/ssh.go +++ b/sshbytestream/ssh.go @@ -1,12 +1,10 @@ package sshbytestream import ( - "bytes" - "context" "fmt" + "github.com/zrepl/zrepl/util" "io" "os" - "os/exec" ) type Error struct { @@ -54,7 +52,7 @@ func (f IncomingReadWriteCloser) Close() (err error) { return } -func Outgoing(remote SSHTransport) (f *ForkExecReadWriter, err error) { +func Outgoing(remote SSHTransport) (c *util.IOCommand, err error) { sshArgs := make([]string, 0, 2*len(remote.Options)+4) sshArgs = append(sshArgs, @@ -72,64 +70,14 @@ func Outgoing(remote SSHTransport) (f *ForkExecReadWriter, err error) { if len(remote.SSHCommand) > 0 { sshCommand = SSHCommand } - ctx, cancel := context.WithCancel(context.Background()) - cmd := exec.CommandContext(ctx, sshCommand, sshArgs...) - // Clear environment of cmd - cmd.Env = []string{} - - var in io.WriteCloser - var out io.ReadCloser - - if in, err = cmd.StdinPipe(); err != nil { - return - } - if out, err = cmd.StdoutPipe(); err != nil { + if c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil { return } - stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024)) - cmd.Stderr = stderrBuf + // Clear environment of cmd, ssh shall not rely on SSH_AUTH_SOCK, etc. + c.Cmd.Env = []string{} - f = &ForkExecReadWriter{ - Stdin: in, - Stdout: out, - Command: cmd, - CommandCancel: cancel, - StderrBuf: stderrBuf, - } - - err = cmd.Start() + err = c.Start() return } - -type ForkExecReadWriter struct { - Command *exec.Cmd - CommandCancel context.CancelFunc - Stdin io.Writer - Stdout io.Reader - StderrBuf *bytes.Buffer -} - -func (f *ForkExecReadWriter) Read(buf []byte) (n int, err error) { - n, err = f.Stdout.Read(buf) - if err == io.EOF { - waitErr := f.Command.Wait() - if waitErr != nil { - err = Error{ - WaitErr: waitErr, - Stderr: f.StderrBuf.Bytes(), - } - } - } - return -} - -func (f *ForkExecReadWriter) Write(p []byte) (n int, err error) { - return f.Stdin.Write(p) -} - -func (f *ForkExecReadWriter) Close() error { - f.CommandCancel() - return nil -} diff --git a/util/iocommand.go b/util/iocommand.go new file mode 100644 index 0000000..3646ef1 --- /dev/null +++ b/util/iocommand.go @@ -0,0 +1,107 @@ +package util + +import ( + "bytes" + "context" + "fmt" + "io" + "os/exec" +) + +// An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface. +type IOCommand struct { + Cmd *exec.Cmd + CmdContext context.Context + CmdCancel context.CancelFunc + Stdin io.Writer + Stdout io.Reader + StderrBuf *bytes.Buffer +} + +const IOCommandStderrBufSize = 1024 + +type IOCommandError struct { + WaitErr error + Stderr []byte +} + +func (e IOCommandError) Error() string { + return fmt.Sprintf("underlying process exited with error: %s\nstderr: %s\n", e.WaitErr, e.Stderr) +} + +func RunIOCommand(command string, args ...string) (c *IOCommand, err error) { + c, err = NewIOCommand(command, args, IOCommandStderrBufSize) + if err != nil { + return + } + err = c.Start() + return +} + +func NewIOCommand(command string, args []string, stderrBufSize int) (c *IOCommand, err error) { + + if stderrBufSize == 0 { + stderrBufSize = IOCommandStderrBufSize + } + + c = &IOCommand{} + + c.CmdContext, c.CmdCancel = context.WithCancel(context.Background()) + c.Cmd = exec.CommandContext(c.CmdContext, command, args...) + + if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil { + return + } + + if c.Stdin, err = c.Cmd.StdinPipe(); err != nil { + return + } + + c.StderrBuf = bytes.NewBuffer(make([]byte, 0, stderrBufSize)) + c.Cmd.Stderr = c.StderrBuf + + return + +} + +func (c *IOCommand) Start() (err error) { + if err = c.Cmd.Start(); err != nil { + return + } + return +} + +// Read from process's stdout. +// The behavior after Close()ing is undefined +func (c *IOCommand) Read(buf []byte) (n int, err error) { + n, err = c.Stdout.Read(buf) + if err == io.EOF { + if waitErr := c.doWait(); waitErr != nil { + err = waitErr + } + } + return +} + +func (c *IOCommand) doWait() (err error) { + waitErr := c.Cmd.Wait() + if waitErr != nil { + err = IOCommandError{ + WaitErr: waitErr, + Stderr: c.StderrBuf.Bytes(), + } + } + return +} + +// Write to process's stdin. +// The behavior after Close()ing is undefined +func (c *IOCommand) Write(buf []byte) (n int, err error) { + return c.Stdin.Write(buf) +} + +// Kill the child process and collect its exit status +func (c *IOCommand) Close() error { + c.CmdCancel() + return c.doWait() +} diff --git a/zfs/fork_reader.go b/zfs/fork_reader.go deleted file mode 100644 index f6126c6..0000000 --- a/zfs/fork_reader.go +++ /dev/null @@ -1,62 +0,0 @@ -package zfs - -import ( - "bytes" - "fmt" - "io" - "os/exec" -) - -// A ForkReader is an io.Reader for a forked process's stdout. -// It Wait()s for the process to exit and - if it exits with error - returns this exit error -// on subsequent Read()s. -type ForkExecReader struct { - Cmd *exec.Cmd - InStream io.Reader - StderrBuf *bytes.Buffer -} - -func NewForkExecReader(command string, args ...string) (r *ForkExecReader, err error) { - - r = &ForkExecReader{} - - r.Cmd = exec.Command(command, args...) - - r.InStream, err = r.Cmd.StdoutPipe() - if err != nil { - return - } - - r.StderrBuf = bytes.NewBuffer(make([]byte, 0, 1024)) - r.Cmd.Stderr = r.StderrBuf - - if err = r.Cmd.Start(); err != nil { - return - } - - return - -} - -type ForkExecReaderError struct { - WaitErr error - Stderr []byte -} - -func (e ForkExecReaderError) Error() string { - return fmt.Sprintf("underlying process exited with error: %s\nstderr: %s\n", e.WaitErr, e.Stderr) -} - -func (t *ForkExecReader) Read(buf []byte) (n int, err error) { - n, err = t.InStream.Read(buf) - if err == io.EOF { - waitErr := t.Cmd.Wait() - if waitErr != nil { - err = ForkExecReaderError{ - WaitErr: waitErr, - Stderr: t.StderrBuf.Bytes(), - } - } - } - return -} diff --git a/zfs/zfs.go b/zfs/zfs.go index 7b487f3..75be8a3 100644 --- a/zfs/zfs.go +++ b/zfs/zfs.go @@ -5,6 +5,7 @@ import ( "bytes" "errors" "fmt" + "github.com/zrepl/zrepl/util" "io" "os/exec" "strings" @@ -114,7 +115,7 @@ func ZFSSend(fs DatasetPath, from, to *FilesystemVersion) (stream io.Reader, err args = append(args, "-i", from.ToAbsPath(fs), to.ToAbsPath(fs)) } - stream, err = NewForkExecReader(ZFS_BINARY, args...) + stream, err = util.RunIOCommand(ZFS_BINARY, args...) return }