Simplify "fork then io.Reader" abstractions

This commit is contained in:
Christian Schwarz 2017-05-08 21:28:18 +02:00
parent 54778c0374
commit dd6dd60e98
3 changed files with 66 additions and 103 deletions

View File

@ -7,7 +7,6 @@ import (
"io"
"os"
"os/exec"
"sync"
)
type Error struct {
@ -55,9 +54,7 @@ func (f IncomingReadWriteCloser) Close() (err error) {
return
}
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
ctx, cancel := context.WithCancel(context.Background())
func Outgoing(remote SSHTransport) (f *ForkExecReadWriter, err error) {
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
sshArgs = append(sshArgs,
@ -75,6 +72,7 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
if len(remote.SSHCommand) > 0 {
sshCommand = SSHCommand
}
ctx, cancel := context.WithCancel(context.Background())
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
// Clear environment of cmd
@ -93,72 +91,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
cmd.Stderr = stderrBuf
f := &ForkedSSHReadWriteCloser{
RemoteStdin: in,
RemoteStdout: out,
Cancel: cancel,
f = &ForkExecReadWriter{
Stdin: in,
Stdout: out,
Command: cmd,
exitWaitGroup: &sync.WaitGroup{},
CommandCancel: cancel,
StderrBuf: stderrBuf,
}
f.exitWaitGroup.Add(1)
if err = cmd.Start(); err != nil {
return
}
err = cmd.Start()
return
}
go func() {
defer f.exitWaitGroup.Done()
type ForkExecReadWriter struct {
Command *exec.Cmd
CommandCancel context.CancelFunc
Stdin io.Writer
Stdout io.Reader
StderrBuf *bytes.Buffer
}
// stderr output is only relevant for errors if the exit code is non-zero
if err := cmd.Wait(); err != nil {
f.SSHCommandError = Error{
Stderr: stderrBuf.Bytes(),
WaitErr: err,
func (f *ForkExecReadWriter) Read(buf []byte) (n int, err error) {
n, err = f.Stdout.Read(buf)
if err == io.EOF {
waitErr := f.Command.Wait()
if waitErr != nil {
err = Error{
WaitErr: waitErr,
Stderr: f.StderrBuf.Bytes(),
}
} else {
f.SSHCommandError = io.EOF
}
}()
return f, nil
}
type ForkedSSHReadWriteCloser struct {
RemoteStdin io.Writer
RemoteStdout io.Reader
Command *exec.Cmd
Cancel context.CancelFunc
exitWaitGroup *sync.WaitGroup
SSHCommandError error
}
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
err = f.SSHCommandError
}
return
}
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
err = f.SSHCommandError
}
return
func (f *ForkExecReadWriter) Write(p []byte) (n int, err error) {
return f.Stdin.Write(p)
}
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
// TODO should check SSHCommandError?
f.Cancel()
f.exitWaitGroup.Wait()
return f.SSHCommandError
func (f *ForkExecReadWriter) Close() error {
f.CommandCancel()
return nil
}

View File

@ -2,68 +2,60 @@ package zfs
import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"sync"
)
// A ForkReader is an io.Reader for a forked process's stdout.
// It Wait()s for the process to exit and - if it exits with error - returns this exit error
// on subsequent Read()s.
type ForkReader struct {
cancelFunc context.CancelFunc
cmd *exec.Cmd
stdout io.Reader
waitErr error
exitWaitGroup sync.WaitGroup
type ForkExecReader struct {
Cmd *exec.Cmd
InStream io.Reader
StderrBuf *bytes.Buffer
}
func NewForkReader(command string, args ...string) (r *ForkReader, err error) {
func NewForkExecReader(command string, args ...string) (r *ForkExecReader, err error) {
r = &ForkReader{}
r = &ForkExecReader{}
var ctx context.Context
ctx, r.cancelFunc = context.WithCancel(context.Background())
r.Cmd = exec.Command(command, args...)
cmd := exec.CommandContext(ctx, command, args...)
stderr := bytes.NewBuffer(make([]byte, 0, 1024))
cmd.Stderr = stderr
if r.stdout, err = cmd.StdoutPipe(); err != nil {
r.InStream, err = r.Cmd.StdoutPipe()
if err != nil {
return
}
if err = cmd.Start(); err != nil {
r.StderrBuf = bytes.NewBuffer(make([]byte, 0, 1024))
r.Cmd.Stderr = r.StderrBuf
if err = r.Cmd.Start(); err != nil {
return
}
r.exitWaitGroup.Add(1)
go func() {
defer r.exitWaitGroup.Done()
if err := cmd.Wait(); err != nil {
os.Stderr.WriteString(err.Error())
r.waitErr = ZFSError{
Stderr: stderr.Bytes(),
WaitErr: err,
return
}
type ForkExecReaderError struct {
WaitErr error
Stderr []byte
}
func (e ForkExecReaderError) Error() string {
return fmt.Sprintf("underlying process exited with error: %s\nstderr: %s\n", e.WaitErr, e.Stderr)
}
func (t *ForkExecReader) Read(buf []byte) (n int, err error) {
n, err = t.InStream.Read(buf)
if err == io.EOF {
waitErr := t.Cmd.Wait()
if waitErr != nil {
err = ForkExecReaderError{
WaitErr: waitErr,
Stderr: t.StderrBuf.Bytes(),
}
return
}
}()
return
}
func (r *ForkReader) Read(buf []byte) (n int, err error) {
if r.waitErr != nil {
return 0, r.waitErr
}
if n, err = r.stdout.Read(buf); err == io.EOF {
// the command has exited but we need to wait for Wait()ing goroutine to finish
r.exitWaitGroup.Wait()
if r.waitErr != nil {
err = r.waitErr
}
}
return

View File

@ -114,7 +114,7 @@ func ZFSSend(fs DatasetPath, from, to *FilesystemVersion) (stream io.Reader, err
args = append(args, "-i", from.ToAbsPath(fs), to.ToAbsPath(fs))
}
stream, err = NewForkReader(ZFS_BINARY, args...)
stream, err = NewForkExecReader(ZFS_BINARY, args...)
return
}