mirror of
https://github.com/zrepl/zrepl.git
synced 2025-02-16 10:29:54 +01:00
sshbytestream: fix semantics when ssh connection dies
in the Wait()ing goroutine, we create an instance of our own error containing WaitErr + stderr excerpt. Only if the remote command exits non-zero. io.EOF is the error we get as soon as the other end of the pipe (the ssh command) has died. => wait for it in the ReadWriter methods.
This commit is contained in:
parent
226935ddea
commit
ec4284f80c
@ -10,6 +10,15 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Stderr []byte
|
||||||
|
WaitErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Error) Error() string {
|
||||||
|
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
|
||||||
|
}
|
||||||
|
|
||||||
type SSHTransport struct {
|
type SSHTransport struct {
|
||||||
Host string
|
Host string
|
||||||
User string
|
User string
|
||||||
@ -37,8 +46,13 @@ func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f IncomingReadWriteCloser) Close() (err error) {
|
func (f IncomingReadWriteCloser) Close() (err error) {
|
||||||
os.Exit(0)
|
if err = os.Stdin.Close(); err != nil {
|
||||||
return nil
|
return
|
||||||
|
}
|
||||||
|
if err = os.Stdout.Close(); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||||
@ -68,20 +82,18 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
|||||||
|
|
||||||
var in io.WriteCloser
|
var in io.WriteCloser
|
||||||
var out io.ReadCloser
|
var out io.ReadCloser
|
||||||
var stderr io.Reader
|
|
||||||
|
|
||||||
if in, err = cmd.StdinPipe(); err != nil {
|
if in, err = cmd.StdinPipe(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if out, err = cmd.StdoutPipe(); err != nil {
|
if out, err = cmd.StdoutPipe(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if stderr, err = cmd.StderrPipe(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
f := ForkedSSHReadWriteCloser{
|
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
|
||||||
|
cmd.Stderr = stderrBuf
|
||||||
|
|
||||||
|
f := &ForkedSSHReadWriteCloser{
|
||||||
RemoteStdin: in,
|
RemoteStdin: in,
|
||||||
RemoteStdout: out,
|
RemoteStdout: out,
|
||||||
Cancel: cancel,
|
Cancel: cancel,
|
||||||
@ -96,13 +108,16 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer f.exitWaitGroup.Done()
|
defer f.exitWaitGroup.Done()
|
||||||
var b bytes.Buffer
|
|
||||||
if _, err := io.Copy(&b, stderr); err != nil {
|
// stderr output is only relevant for errors if the exit code is non-zero
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if err := cmd.Wait(); err != nil {
|
if err := cmd.Wait(); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b)
|
f.SSHCommandError = Error{
|
||||||
//panic(err) TODO
|
Stderr: stderrBuf.Bytes(),
|
||||||
|
WaitErr: err,
|
||||||
|
}
|
||||||
|
// fmt.Fprintf(os.Stderr, "ssh conn wait err: %#v\n", f.SSHCommandError.(Error))
|
||||||
|
} else {
|
||||||
|
f.SSHCommandError = nil
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -110,23 +125,45 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ForkedSSHReadWriteCloser struct {
|
type ForkedSSHReadWriteCloser struct {
|
||||||
RemoteStdin io.Writer
|
RemoteStdin io.Writer
|
||||||
RemoteStdout io.Reader
|
RemoteStdout io.Reader
|
||||||
Command *exec.Cmd
|
Command *exec.Cmd
|
||||||
Cancel context.CancelFunc
|
Cancel context.CancelFunc
|
||||||
exitWaitGroup *sync.WaitGroup
|
exitWaitGroup *sync.WaitGroup
|
||||||
|
SSHCommandError error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
|
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||||
return f.RemoteStdout.Read(p)
|
if f.SSHCommandError != nil {
|
||||||
|
return 0, f.SSHCommandError
|
||||||
|
}
|
||||||
|
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
|
||||||
|
// the ssh command has exited, but we need to wait for post-portem to finish
|
||||||
|
f.exitWaitGroup.Wait()
|
||||||
|
if f.SSHCommandError != nil {
|
||||||
|
err = f.SSHCommandError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
|
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||||
return f.RemoteStdin.Write(p)
|
if f.SSHCommandError != nil {
|
||||||
|
return 0, f.SSHCommandError
|
||||||
|
}
|
||||||
|
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
|
||||||
|
// the ssh command has exited, but we need to wait for post-portem to finish
|
||||||
|
f.exitWaitGroup.Wait()
|
||||||
|
if f.SSHCommandError != nil {
|
||||||
|
err = f.SSHCommandError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f ForkedSSHReadWriteCloser) Close() (err error) {
|
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
|
||||||
|
// TODO should check SSHCommandError?
|
||||||
f.Cancel()
|
f.Cancel()
|
||||||
f.exitWaitGroup.Wait()
|
f.exitWaitGroup.Wait()
|
||||||
return nil
|
return f.SSHCommandError
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user