diff --git a/sshbytestream/ssh.go b/sshbytestream/ssh.go index 2a226c3..23c745f 100644 --- a/sshbytestream/ssh.go +++ b/sshbytestream/ssh.go @@ -7,7 +7,6 @@ import ( "io" "os" "os/exec" - "sync" ) type Error struct { @@ -55,9 +54,7 @@ func (f IncomingReadWriteCloser) Close() (err error) { return } -func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { - - ctx, cancel := context.WithCancel(context.Background()) +func Outgoing(remote SSHTransport) (f *ForkExecReadWriter, err error) { sshArgs := make([]string, 0, 2*len(remote.Options)+4) sshArgs = append(sshArgs, @@ -75,6 +72,7 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { if len(remote.SSHCommand) > 0 { sshCommand = SSHCommand } + ctx, cancel := context.WithCancel(context.Background()) cmd := exec.CommandContext(ctx, sshCommand, sshArgs...) // Clear environment of cmd @@ -93,72 +91,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024)) cmd.Stderr = stderrBuf - f := &ForkedSSHReadWriteCloser{ - RemoteStdin: in, - RemoteStdout: out, - Cancel: cancel, + f = &ForkExecReadWriter{ + Stdin: in, + Stdout: out, Command: cmd, - exitWaitGroup: &sync.WaitGroup{}, + CommandCancel: cancel, + StderrBuf: stderrBuf, } - f.exitWaitGroup.Add(1) - if err = cmd.Start(); err != nil { - return - } + err = cmd.Start() + return +} - go func() { - defer f.exitWaitGroup.Done() +type ForkExecReadWriter struct { + Command *exec.Cmd + CommandCancel context.CancelFunc + Stdin io.Writer + Stdout io.Reader + StderrBuf *bytes.Buffer +} - // stderr output is only relevant for errors if the exit code is non-zero - if err := cmd.Wait(); err != nil { - f.SSHCommandError = Error{ - Stderr: stderrBuf.Bytes(), - WaitErr: err, +func (f *ForkExecReadWriter) Read(buf []byte) (n int, err error) { + n, err = f.Stdout.Read(buf) + if err == io.EOF { + waitErr := f.Command.Wait() + if waitErr != nil { + err = Error{ + WaitErr: waitErr, + Stderr: f.StderrBuf.Bytes(), } - } else { - f.SSHCommandError = io.EOF } - }() - - return f, nil -} - -type ForkedSSHReadWriteCloser struct { - RemoteStdin io.Writer - RemoteStdout io.Reader - Command *exec.Cmd - Cancel context.CancelFunc - exitWaitGroup *sync.WaitGroup - SSHCommandError error -} - -func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) { - if f.SSHCommandError != nil { - return 0, f.SSHCommandError - } - if n, err = f.RemoteStdout.Read(p); err == io.EOF { - // the ssh command has exited, but we need to wait for post-portem to finish - f.exitWaitGroup.Wait() - err = f.SSHCommandError } return } -func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) { - if f.SSHCommandError != nil { - return 0, f.SSHCommandError - } - if n, err = f.RemoteStdin.Write(p); err == io.EOF { - // the ssh command has exited, but we need to wait for post-portem to finish - f.exitWaitGroup.Wait() - err = f.SSHCommandError - } - return +func (f *ForkExecReadWriter) Write(p []byte) (n int, err error) { + return f.Stdin.Write(p) } -func (f *ForkedSSHReadWriteCloser) Close() (err error) { - // TODO should check SSHCommandError? - f.Cancel() - f.exitWaitGroup.Wait() - return f.SSHCommandError +func (f *ForkExecReadWriter) Close() error { + f.CommandCancel() + return nil } diff --git a/zfs/fork_reader.go b/zfs/fork_reader.go index 0d0bd82..f6126c6 100644 --- a/zfs/fork_reader.go +++ b/zfs/fork_reader.go @@ -2,68 +2,60 @@ package zfs import ( "bytes" - "context" + "fmt" "io" - "os" "os/exec" - "sync" ) // A ForkReader is an io.Reader for a forked process's stdout. // It Wait()s for the process to exit and - if it exits with error - returns this exit error // on subsequent Read()s. -type ForkReader struct { - cancelFunc context.CancelFunc - cmd *exec.Cmd - stdout io.Reader - waitErr error - exitWaitGroup sync.WaitGroup +type ForkExecReader struct { + Cmd *exec.Cmd + InStream io.Reader + StderrBuf *bytes.Buffer } -func NewForkReader(command string, args ...string) (r *ForkReader, err error) { +func NewForkExecReader(command string, args ...string) (r *ForkExecReader, err error) { - r = &ForkReader{} + r = &ForkExecReader{} - var ctx context.Context - ctx, r.cancelFunc = context.WithCancel(context.Background()) + r.Cmd = exec.Command(command, args...) - cmd := exec.CommandContext(ctx, command, args...) - - stderr := bytes.NewBuffer(make([]byte, 0, 1024)) - cmd.Stderr = stderr - - if r.stdout, err = cmd.StdoutPipe(); err != nil { + r.InStream, err = r.Cmd.StdoutPipe() + if err != nil { return } - if err = cmd.Start(); err != nil { + r.StderrBuf = bytes.NewBuffer(make([]byte, 0, 1024)) + r.Cmd.Stderr = r.StderrBuf + + if err = r.Cmd.Start(); err != nil { return } - r.exitWaitGroup.Add(1) - go func() { - defer r.exitWaitGroup.Done() - if err := cmd.Wait(); err != nil { - os.Stderr.WriteString(err.Error()) - r.waitErr = ZFSError{ - Stderr: stderr.Bytes(), - WaitErr: err, + return + +} + +type ForkExecReaderError struct { + WaitErr error + Stderr []byte +} + +func (e ForkExecReaderError) Error() string { + return fmt.Sprintf("underlying process exited with error: %s\nstderr: %s\n", e.WaitErr, e.Stderr) +} + +func (t *ForkExecReader) Read(buf []byte) (n int, err error) { + n, err = t.InStream.Read(buf) + if err == io.EOF { + waitErr := t.Cmd.Wait() + if waitErr != nil { + err = ForkExecReaderError{ + WaitErr: waitErr, + Stderr: t.StderrBuf.Bytes(), } - return - } - }() - return -} - -func (r *ForkReader) Read(buf []byte) (n int, err error) { - if r.waitErr != nil { - return 0, r.waitErr - } - if n, err = r.stdout.Read(buf); err == io.EOF { - // the command has exited but we need to wait for Wait()ing goroutine to finish - r.exitWaitGroup.Wait() - if r.waitErr != nil { - err = r.waitErr } } return diff --git a/zfs/zfs.go b/zfs/zfs.go index 115afad..7b487f3 100644 --- a/zfs/zfs.go +++ b/zfs/zfs.go @@ -114,7 +114,7 @@ func ZFSSend(fs DatasetPath, from, to *FilesystemVersion) (stream io.Reader, err args = append(args, "-i", from.ToAbsPath(fs), to.ToAbsPath(fs)) } - stream, err = NewForkReader(ZFS_BINARY, args...) + stream, err = NewForkExecReader(ZFS_BINARY, args...) return }