sshbytestream & IOCommand: fix handling of dead child process

SSH catches SIGTERM, tears down its connection, then exits with
platform-specific exit code.
This commit is contained in:
Christian Schwarz 2017-08-09 21:01:06 +02:00
parent e2bbd4287e
commit ca1a482e9e
2 changed files with 49 additions and 19 deletions

View File

@ -2,9 +2,12 @@ package sshbytestream
import (
"fmt"
"github.com/zrepl/zrepl/util"
"io"
"os"
"runtime"
"syscall"
"github.com/zrepl/zrepl/util"
)
type Error struct {
@ -53,7 +56,11 @@ func (f IncomingReadWriteCloser) Close() (err error) {
return
}
func Outgoing(remote SSHTransport) (c *util.IOCommand, err error) {
type OutgoingSSHByteStream struct {
c *util.IOCommand
}
func Outgoing(remote SSHTransport) (s OutgoingSSHByteStream, err error) {
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
sshArgs = append(sshArgs,
@ -73,13 +80,44 @@ func Outgoing(remote SSHTransport) (c *util.IOCommand, err error) {
sshCommand = SSHCommand
}
if c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil {
if s.c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil {
return
}
// Clear environment of cmd, ssh shall not rely on SSH_AUTH_SOCK, etc.
c.Cmd.Env = []string{}
s.c.Cmd.Env = []string{}
err = s.c.Start()
return
}
func (s OutgoingSSHByteStream) Read(p []byte) (n int, err error) {
return s.c.Read(p)
}
func (s OutgoingSSHByteStream) Write(p []byte) (n int, err error) {
return s.c.Write(p)
}
func (s OutgoingSSHByteStream) Close() (err error) {
err = s.c.Close()
if err == nil || s.c.ExitResult == nil {
return
}
// SSH catches SIGTERM and has different exit codes on different platforms
ws := s.c.ExitResult.WaitStatus
switch runtime.GOOS {
case "linux":
if ws.ExitStatus() == 128+int(syscall.SIGTERM) { // OpenSSH_7.5p1, OpenSSL 1.1.0f 25 May 2017 Arch Linux
err = nil
}
case "freebsd": // OpenSSH_7.2p2, OpenSSL 1.0.2k-freebsd 26 Jan 2017
if ws.ExitStatus() == 255 {
err = nil
}
default: // TODO
}
err = c.Start()
return
}

View File

@ -3,7 +3,6 @@ package util
import (
"bytes"
"fmt"
"golang.org/x/sys/unix"
"io"
"os/exec"
"syscall"
@ -15,7 +14,7 @@ type IOCommand struct {
Stdin io.Writer
Stdout io.Reader
StderrBuf *bytes.Buffer
ExitResult IOCommandExitResult
ExitResult *IOCommandExitResult
}
const IOCommandStderrBufSize = 1024
@ -90,24 +89,17 @@ func (c *IOCommand) Read(buf []byte) (n int, err error) {
func (c *IOCommand) doWait() (err error) {
waitErr := c.Cmd.Wait()
waitStatus := c.Cmd.ProcessState.Sys().(syscall.WaitStatus) // Fail hard if we're not on UNIX
if waitErr != nil {
// https://support.ssh.com/manuals/client-user/44/ssh2_Return_Values.html
// If ssh is terminated via signal, its exit status is 128 + signal number
if waitStatus.ExitStatus() == 128+int(syscall.SIGTERM) {
// discard wait err, we assume this is due to earlier c.Close()
goto out
}
wasUs := waitStatus.Signaled() && waitStatus.Signal() == syscall.SIGTERM // in Close()
if waitErr != nil && !wasUs {
err = IOCommandError{
WaitErr: waitErr,
Stderr: c.StderrBuf.Bytes(),
}
}
out:
c.ExitResult = IOCommandExitResult{
Error: err,
c.ExitResult = &IOCommandExitResult{
Error: err, // is still empty if waitErr was due to signalling
WaitStatus: waitStatus,
}
return
@ -124,7 +116,7 @@ func (c *IOCommand) Write(buf []byte) (n int, err error) {
func (c *IOCommand) Close() (err error) {
if c.Cmd.ProcessState == nil {
// racy...
err = unix.Kill(c.Cmd.Process.Pid, syscall.SIGTERM)
err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM)
if err != nil {
return
}