diff --git a/sshbytestream/ssh.go b/sshbytestream/ssh.go index 3e5b328..ac18c35 100644 --- a/sshbytestream/ssh.go +++ b/sshbytestream/ssh.go @@ -10,6 +10,15 @@ import ( "sync" ) +type Error struct { + Stderr []byte + WaitErr error +} + +func (e Error) Error() string { + return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr) +} + type SSHTransport struct { Host string User string @@ -37,8 +46,13 @@ func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) { } func (f IncomingReadWriteCloser) Close() (err error) { - os.Exit(0) - return nil + if err = os.Stdin.Close(); err != nil { + return + } + if err = os.Stdout.Close(); err != nil { + return + } + return } func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { @@ -68,20 +82,18 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { var in io.WriteCloser var out io.ReadCloser - var stderr io.Reader if in, err = cmd.StdinPipe(); err != nil { return } - if out, err = cmd.StdoutPipe(); err != nil { return } - if stderr, err = cmd.StderrPipe(); err != nil { - return - } - f := ForkedSSHReadWriteCloser{ + stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024)) + cmd.Stderr = stderrBuf + + f := &ForkedSSHReadWriteCloser{ RemoteStdin: in, RemoteStdout: out, Cancel: cancel, @@ -96,13 +108,16 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { go func() { defer f.exitWaitGroup.Done() - var b bytes.Buffer - if _, err := io.Copy(&b, stderr); err != nil { - panic(err) - } + + // stderr output is only relevant for errors if the exit code is non-zero if err := cmd.Wait(); err != nil { - fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b) - //panic(err) TODO + f.SSHCommandError = Error{ + Stderr: stderrBuf.Bytes(), + WaitErr: err, + } + // fmt.Fprintf(os.Stderr, "ssh conn wait err: %#v\n", f.SSHCommandError.(Error)) + } else { + f.SSHCommandError = nil } }() @@ -110,23 +125,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { } type ForkedSSHReadWriteCloser struct { - RemoteStdin io.Writer - RemoteStdout io.Reader - Command *exec.Cmd - Cancel context.CancelFunc - exitWaitGroup *sync.WaitGroup + RemoteStdin io.Writer + RemoteStdout io.Reader + Command *exec.Cmd + Cancel context.CancelFunc + exitWaitGroup *sync.WaitGroup + SSHCommandError error } -func (f ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) { - return f.RemoteStdout.Read(p) +func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) { + if f.SSHCommandError != nil { + return 0, f.SSHCommandError + } + if n, err = f.RemoteStdout.Read(p); err == io.EOF { + // the ssh command has exited, but we need to wait for post-portem to finish + f.exitWaitGroup.Wait() + if f.SSHCommandError != nil { + err = f.SSHCommandError + } + } + return } -func (f ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) { - return f.RemoteStdin.Write(p) +func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) { + if f.SSHCommandError != nil { + return 0, f.SSHCommandError + } + if n, err = f.RemoteStdin.Write(p); err == io.EOF { + // the ssh command has exited, but we need to wait for post-portem to finish + f.exitWaitGroup.Wait() + if f.SSHCommandError != nil { + err = f.SSHCommandError + } + } + return } -func (f ForkedSSHReadWriteCloser) Close() (err error) { +func (f *ForkedSSHReadWriteCloser) Close() (err error) { + // TODO should check SSHCommandError? f.Cancel() f.exitWaitGroup.Wait() - return nil + return f.SSHCommandError }