sshbytestream: fix semantics when ssh connection dies

in the Wait()ing goroutine, we create an instance of our own error
containing WaitErr + stderr excerpt. Only if the remote command exits
non-zero.

io.EOF is the error we get as soon as the other end of the pipe
(the ssh command) has died.
=> wait for it in the ReadWriter methods.
This commit is contained in:
Christian Schwarz 2017-04-30 23:35:08 +02:00
parent 226935ddea
commit ec4284f80c

View File

@ -10,6 +10,15 @@ import (
"sync"
)
type Error struct {
Stderr []byte
WaitErr error
}
func (e Error) Error() string {
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
}
type SSHTransport struct {
Host string
User string
@ -37,8 +46,13 @@ func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
}
func (f IncomingReadWriteCloser) Close() (err error) {
os.Exit(0)
return nil
if err = os.Stdin.Close(); err != nil {
return
}
if err = os.Stdout.Close(); err != nil {
return
}
return
}
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
@ -68,20 +82,18 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
var in io.WriteCloser
var out io.ReadCloser
var stderr io.Reader
if in, err = cmd.StdinPipe(); err != nil {
return
}
if out, err = cmd.StdoutPipe(); err != nil {
return
}
if stderr, err = cmd.StderrPipe(); err != nil {
return
}
f := ForkedSSHReadWriteCloser{
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
cmd.Stderr = stderrBuf
f := &ForkedSSHReadWriteCloser{
RemoteStdin: in,
RemoteStdout: out,
Cancel: cancel,
@ -96,13 +108,16 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
go func() {
defer f.exitWaitGroup.Done()
var b bytes.Buffer
if _, err := io.Copy(&b, stderr); err != nil {
panic(err)
}
// stderr output is only relevant for errors if the exit code is non-zero
if err := cmd.Wait(); err != nil {
fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b)
//panic(err) TODO
f.SSHCommandError = Error{
Stderr: stderrBuf.Bytes(),
WaitErr: err,
}
// fmt.Fprintf(os.Stderr, "ssh conn wait err: %#v\n", f.SSHCommandError.(Error))
} else {
f.SSHCommandError = nil
}
}()
@ -110,23 +125,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
}
type ForkedSSHReadWriteCloser struct {
RemoteStdin io.Writer
RemoteStdout io.Reader
Command *exec.Cmd
Cancel context.CancelFunc
exitWaitGroup *sync.WaitGroup
RemoteStdin io.Writer
RemoteStdout io.Reader
Command *exec.Cmd
Cancel context.CancelFunc
exitWaitGroup *sync.WaitGroup
SSHCommandError error
}
func (f ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
return f.RemoteStdout.Read(p)
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
if f.SSHCommandError != nil {
err = f.SSHCommandError
}
}
return
}
func (f ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
return f.RemoteStdin.Write(p)
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
if f.SSHCommandError != nil {
err = f.SSHCommandError
}
}
return
}
func (f ForkedSSHReadWriteCloser) Close() (err error) {
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
// TODO should check SSHCommandError?
f.Cancel()
f.exitWaitGroup.Wait()
return nil
return f.SSHCommandError
}