sshbytestream: fix semantics when ssh connection dies

in the Wait()ing goroutine, we create an instance of our own error
containing WaitErr + stderr excerpt. Only if the remote command exits
non-zero.

io.EOF is the error we get as soon as the other end of the pipe
(the ssh command) has died.
=> wait for it in the ReadWriter methods.
This commit is contained in:
Christian Schwarz 2017-04-30 23:35:08 +02:00
parent 226935ddea
commit ec4284f80c

View File

@ -10,6 +10,15 @@ import (
"sync" "sync"
) )
type Error struct {
Stderr []byte
WaitErr error
}
func (e Error) Error() string {
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
}
type SSHTransport struct { type SSHTransport struct {
Host string Host string
User string User string
@ -37,8 +46,13 @@ func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
} }
func (f IncomingReadWriteCloser) Close() (err error) { func (f IncomingReadWriteCloser) Close() (err error) {
os.Exit(0) if err = os.Stdin.Close(); err != nil {
return nil return
}
if err = os.Stdout.Close(); err != nil {
return
}
return
} }
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
@ -68,20 +82,18 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
var in io.WriteCloser var in io.WriteCloser
var out io.ReadCloser var out io.ReadCloser
var stderr io.Reader
if in, err = cmd.StdinPipe(); err != nil { if in, err = cmd.StdinPipe(); err != nil {
return return
} }
if out, err = cmd.StdoutPipe(); err != nil { if out, err = cmd.StdoutPipe(); err != nil {
return return
} }
if stderr, err = cmd.StderrPipe(); err != nil {
return
}
f := ForkedSSHReadWriteCloser{ stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
cmd.Stderr = stderrBuf
f := &ForkedSSHReadWriteCloser{
RemoteStdin: in, RemoteStdin: in,
RemoteStdout: out, RemoteStdout: out,
Cancel: cancel, Cancel: cancel,
@ -96,13 +108,16 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
go func() { go func() {
defer f.exitWaitGroup.Done() defer f.exitWaitGroup.Done()
var b bytes.Buffer
if _, err := io.Copy(&b, stderr); err != nil { // stderr output is only relevant for errors if the exit code is non-zero
panic(err)
}
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b) f.SSHCommandError = Error{
//panic(err) TODO Stderr: stderrBuf.Bytes(),
WaitErr: err,
}
// fmt.Fprintf(os.Stderr, "ssh conn wait err: %#v\n", f.SSHCommandError.(Error))
} else {
f.SSHCommandError = nil
} }
}() }()
@ -110,23 +125,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
} }
type ForkedSSHReadWriteCloser struct { type ForkedSSHReadWriteCloser struct {
RemoteStdin io.Writer RemoteStdin io.Writer
RemoteStdout io.Reader RemoteStdout io.Reader
Command *exec.Cmd Command *exec.Cmd
Cancel context.CancelFunc Cancel context.CancelFunc
exitWaitGroup *sync.WaitGroup exitWaitGroup *sync.WaitGroup
SSHCommandError error
} }
func (f ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) { func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
return f.RemoteStdout.Read(p) if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
if f.SSHCommandError != nil {
err = f.SSHCommandError
}
}
return
} }
func (f ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) { func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
return f.RemoteStdin.Write(p) if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
if f.SSHCommandError != nil {
err = f.SSHCommandError
}
}
return
} }
func (f ForkedSSHReadWriteCloser) Close() (err error) { func (f *ForkedSSHReadWriteCloser) Close() (err error) {
// TODO should check SSHCommandError?
f.Cancel() f.Cancel()
f.exitWaitGroup.Wait() f.exitWaitGroup.Wait()
return nil return f.SSHCommandError
} }