mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-21 16:03:32 +01:00
sshbytestream: fix semantics when ssh connection dies
in the Wait()ing goroutine, we create an instance of our own error containing WaitErr + stderr excerpt. Only if the remote command exits non-zero. io.EOF is the error we get as soon as the other end of the pipe (the ssh command) has died. => wait for it in the ReadWriter methods.
This commit is contained in:
parent
226935ddea
commit
ec4284f80c
@ -10,6 +10,15 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Stderr []byte
|
||||
WaitErr error
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
|
||||
}
|
||||
|
||||
type SSHTransport struct {
|
||||
Host string
|
||||
User string
|
||||
@ -37,8 +46,13 @@ func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (f IncomingReadWriteCloser) Close() (err error) {
|
||||
os.Exit(0)
|
||||
return nil
|
||||
if err = os.Stdin.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = os.Stdout.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
@ -68,20 +82,18 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
|
||||
var in io.WriteCloser
|
||||
var out io.ReadCloser
|
||||
var stderr io.Reader
|
||||
|
||||
if in, err = cmd.StdinPipe(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if out, err = cmd.StdoutPipe(); err != nil {
|
||||
return
|
||||
}
|
||||
if stderr, err = cmd.StderrPipe(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f := ForkedSSHReadWriteCloser{
|
||||
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||
cmd.Stderr = stderrBuf
|
||||
|
||||
f := &ForkedSSHReadWriteCloser{
|
||||
RemoteStdin: in,
|
||||
RemoteStdout: out,
|
||||
Cancel: cancel,
|
||||
@ -96,13 +108,16 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
|
||||
go func() {
|
||||
defer f.exitWaitGroup.Done()
|
||||
var b bytes.Buffer
|
||||
if _, err := io.Copy(&b, stderr); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// stderr output is only relevant for errors if the exit code is non-zero
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b)
|
||||
//panic(err) TODO
|
||||
f.SSHCommandError = Error{
|
||||
Stderr: stderrBuf.Bytes(),
|
||||
WaitErr: err,
|
||||
}
|
||||
// fmt.Fprintf(os.Stderr, "ssh conn wait err: %#v\n", f.SSHCommandError.(Error))
|
||||
} else {
|
||||
f.SSHCommandError = nil
|
||||
}
|
||||
}()
|
||||
|
||||
@ -110,23 +125,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
}
|
||||
|
||||
type ForkedSSHReadWriteCloser struct {
|
||||
RemoteStdin io.Writer
|
||||
RemoteStdout io.Reader
|
||||
Command *exec.Cmd
|
||||
Cancel context.CancelFunc
|
||||
exitWaitGroup *sync.WaitGroup
|
||||
RemoteStdin io.Writer
|
||||
RemoteStdout io.Reader
|
||||
Command *exec.Cmd
|
||||
Cancel context.CancelFunc
|
||||
exitWaitGroup *sync.WaitGroup
|
||||
SSHCommandError error
|
||||
}
|
||||
|
||||
func (f ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||
return f.RemoteStdout.Read(p)
|
||||
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||
if f.SSHCommandError != nil {
|
||||
return 0, f.SSHCommandError
|
||||
}
|
||||
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
|
||||
// the ssh command has exited, but we need to wait for post-portem to finish
|
||||
f.exitWaitGroup.Wait()
|
||||
if f.SSHCommandError != nil {
|
||||
err = f.SSHCommandError
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return f.RemoteStdin.Write(p)
|
||||
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
if f.SSHCommandError != nil {
|
||||
return 0, f.SSHCommandError
|
||||
}
|
||||
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
|
||||
// the ssh command has exited, but we need to wait for post-portem to finish
|
||||
f.exitWaitGroup.Wait()
|
||||
if f.SSHCommandError != nil {
|
||||
err = f.SSHCommandError
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (f ForkedSSHReadWriteCloser) Close() (err error) {
|
||||
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
|
||||
// TODO should check SSHCommandError?
|
||||
f.Cancel()
|
||||
f.exitWaitGroup.Wait()
|
||||
return nil
|
||||
return f.SSHCommandError
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user