mirror of
https://github.com/zrepl/zrepl.git
synced 2025-04-02 20:16:45 +02:00
Simplify "fork then io.Reader" abstractions
This commit is contained in:
parent
54778c0374
commit
dd6dd60e98
@ -7,7 +7,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
@ -55,9 +54,7 @@ func (f IncomingReadWriteCloser) Close() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
func Outgoing(remote SSHTransport) (f *ForkExecReadWriter, err error) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
|
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
|
||||||
sshArgs = append(sshArgs,
|
sshArgs = append(sshArgs,
|
||||||
@ -75,6 +72,7 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
|||||||
if len(remote.SSHCommand) > 0 {
|
if len(remote.SSHCommand) > 0 {
|
||||||
sshCommand = SSHCommand
|
sshCommand = SSHCommand
|
||||||
}
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
|
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
|
||||||
|
|
||||||
// Clear environment of cmd
|
// Clear environment of cmd
|
||||||
@ -93,72 +91,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
|||||||
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
|
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||||
cmd.Stderr = stderrBuf
|
cmd.Stderr = stderrBuf
|
||||||
|
|
||||||
f := &ForkedSSHReadWriteCloser{
|
f = &ForkExecReadWriter{
|
||||||
RemoteStdin: in,
|
Stdin: in,
|
||||||
RemoteStdout: out,
|
Stdout: out,
|
||||||
Cancel: cancel,
|
|
||||||
Command: cmd,
|
Command: cmd,
|
||||||
exitWaitGroup: &sync.WaitGroup{},
|
CommandCancel: cancel,
|
||||||
|
StderrBuf: stderrBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
f.exitWaitGroup.Add(1)
|
err = cmd.Start()
|
||||||
if err = cmd.Start(); err != nil {
|
return
|
||||||
return
|
}
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
type ForkExecReadWriter struct {
|
||||||
defer f.exitWaitGroup.Done()
|
Command *exec.Cmd
|
||||||
|
CommandCancel context.CancelFunc
|
||||||
|
Stdin io.Writer
|
||||||
|
Stdout io.Reader
|
||||||
|
StderrBuf *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
// stderr output is only relevant for errors if the exit code is non-zero
|
func (f *ForkExecReadWriter) Read(buf []byte) (n int, err error) {
|
||||||
if err := cmd.Wait(); err != nil {
|
n, err = f.Stdout.Read(buf)
|
||||||
f.SSHCommandError = Error{
|
if err == io.EOF {
|
||||||
Stderr: stderrBuf.Bytes(),
|
waitErr := f.Command.Wait()
|
||||||
WaitErr: err,
|
if waitErr != nil {
|
||||||
|
err = Error{
|
||||||
|
WaitErr: waitErr,
|
||||||
|
Stderr: f.StderrBuf.Bytes(),
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
f.SSHCommandError = io.EOF
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ForkedSSHReadWriteCloser struct {
|
|
||||||
RemoteStdin io.Writer
|
|
||||||
RemoteStdout io.Reader
|
|
||||||
Command *exec.Cmd
|
|
||||||
Cancel context.CancelFunc
|
|
||||||
exitWaitGroup *sync.WaitGroup
|
|
||||||
SSHCommandError error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
|
|
||||||
if f.SSHCommandError != nil {
|
|
||||||
return 0, f.SSHCommandError
|
|
||||||
}
|
|
||||||
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
|
|
||||||
// the ssh command has exited, but we need to wait for post-portem to finish
|
|
||||||
f.exitWaitGroup.Wait()
|
|
||||||
err = f.SSHCommandError
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
|
func (f *ForkExecReadWriter) Write(p []byte) (n int, err error) {
|
||||||
if f.SSHCommandError != nil {
|
return f.Stdin.Write(p)
|
||||||
return 0, f.SSHCommandError
|
|
||||||
}
|
|
||||||
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
|
|
||||||
// the ssh command has exited, but we need to wait for post-portem to finish
|
|
||||||
f.exitWaitGroup.Wait()
|
|
||||||
err = f.SSHCommandError
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
|
func (f *ForkExecReadWriter) Close() error {
|
||||||
// TODO should check SSHCommandError?
|
f.CommandCancel()
|
||||||
f.Cancel()
|
return nil
|
||||||
f.exitWaitGroup.Wait()
|
|
||||||
return f.SSHCommandError
|
|
||||||
}
|
}
|
||||||
|
@ -2,68 +2,60 @@ package zfs
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A ForkReader is an io.Reader for a forked process's stdout.
|
// A ForkReader is an io.Reader for a forked process's stdout.
|
||||||
// It Wait()s for the process to exit and - if it exits with error - returns this exit error
|
// It Wait()s for the process to exit and - if it exits with error - returns this exit error
|
||||||
// on subsequent Read()s.
|
// on subsequent Read()s.
|
||||||
type ForkReader struct {
|
type ForkExecReader struct {
|
||||||
cancelFunc context.CancelFunc
|
Cmd *exec.Cmd
|
||||||
cmd *exec.Cmd
|
InStream io.Reader
|
||||||
stdout io.Reader
|
StderrBuf *bytes.Buffer
|
||||||
waitErr error
|
|
||||||
exitWaitGroup sync.WaitGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewForkReader(command string, args ...string) (r *ForkReader, err error) {
|
func NewForkExecReader(command string, args ...string) (r *ForkExecReader, err error) {
|
||||||
|
|
||||||
r = &ForkReader{}
|
r = &ForkExecReader{}
|
||||||
|
|
||||||
var ctx context.Context
|
r.Cmd = exec.Command(command, args...)
|
||||||
ctx, r.cancelFunc = context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
cmd := exec.CommandContext(ctx, command, args...)
|
r.InStream, err = r.Cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
stderr := bytes.NewBuffer(make([]byte, 0, 1024))
|
|
||||||
cmd.Stderr = stderr
|
|
||||||
|
|
||||||
if r.stdout, err = cmd.StdoutPipe(); err != nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = cmd.Start(); err != nil {
|
r.StderrBuf = bytes.NewBuffer(make([]byte, 0, 1024))
|
||||||
|
r.Cmd.Stderr = r.StderrBuf
|
||||||
|
|
||||||
|
if err = r.Cmd.Start(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.exitWaitGroup.Add(1)
|
|
||||||
|
|
||||||
go func() {
|
return
|
||||||
defer r.exitWaitGroup.Done()
|
|
||||||
if err := cmd.Wait(); err != nil {
|
}
|
||||||
os.Stderr.WriteString(err.Error())
|
|
||||||
r.waitErr = ZFSError{
|
type ForkExecReaderError struct {
|
||||||
Stderr: stderr.Bytes(),
|
WaitErr error
|
||||||
WaitErr: err,
|
Stderr []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e ForkExecReaderError) Error() string {
|
||||||
|
return fmt.Sprintf("underlying process exited with error: %s\nstderr: %s\n", e.WaitErr, e.Stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ForkExecReader) Read(buf []byte) (n int, err error) {
|
||||||
|
n, err = t.InStream.Read(buf)
|
||||||
|
if err == io.EOF {
|
||||||
|
waitErr := t.Cmd.Wait()
|
||||||
|
if waitErr != nil {
|
||||||
|
err = ForkExecReaderError{
|
||||||
|
WaitErr: waitErr,
|
||||||
|
Stderr: t.StderrBuf.Bytes(),
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ForkReader) Read(buf []byte) (n int, err error) {
|
|
||||||
if r.waitErr != nil {
|
|
||||||
return 0, r.waitErr
|
|
||||||
}
|
|
||||||
if n, err = r.stdout.Read(buf); err == io.EOF {
|
|
||||||
// the command has exited but we need to wait for Wait()ing goroutine to finish
|
|
||||||
r.exitWaitGroup.Wait()
|
|
||||||
if r.waitErr != nil {
|
|
||||||
err = r.waitErr
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
@ -114,7 +114,7 @@ func ZFSSend(fs DatasetPath, from, to *FilesystemVersion) (stream io.Reader, err
|
|||||||
args = append(args, "-i", from.ToAbsPath(fs), to.ToAbsPath(fs))
|
args = append(args, "-i", from.ToAbsPath(fs), to.ToAbsPath(fs))
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err = NewForkReader(ZFS_BINARY, args...)
|
stream, err = NewForkExecReader(ZFS_BINARY, args...)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user