util/iocommand: timeout kill on close + other hardening

This commit is contained in:
Christian Schwarz 2018-11-06 23:37:25 +01:00
parent 1aae7b222f
commit 7a75a4d384

View File

@ -4,15 +4,18 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/zrepl/zrepl/util/envconst"
"io" "io"
"os" "os"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
) )
// An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface. // An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface.
type IOCommand struct { type IOCommand struct {
Cmd *exec.Cmd Cmd *exec.Cmd
kill context.CancelFunc
Stdin io.WriteCloser Stdin io.WriteCloser
Stdout io.ReadCloser Stdout io.ReadCloser
StderrBuf *bytes.Buffer StderrBuf *bytes.Buffer
@ -52,6 +55,7 @@ func NewIOCommand(ctx context.Context, command string, args []string, stderrBufS
c = &IOCommand{} c = &IOCommand{}
ctx, c.kill = context.WithCancel(ctx)
c.Cmd = exec.CommandContext(ctx, command, args...) c.Cmd = exec.CommandContext(ctx, command, args...)
if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil { if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil {
@ -81,14 +85,24 @@ func (c *IOCommand) Start() (err error) {
func (c *IOCommand) Read(buf []byte) (n int, err error) { func (c *IOCommand) Read(buf []byte) (n int, err error) {
n, err = c.Stdout.Read(buf) n, err = c.Stdout.Read(buf)
if err == io.EOF { if err == io.EOF {
if waitErr := c.doWait(); waitErr != nil { if waitErr := c.doWait(context.Background()); waitErr != nil {
err = waitErr err = waitErr
} }
} }
return return
} }
func (c *IOCommand) doWait() (err error) { func (c *IOCommand) doWait(ctx context.Context) (err error) {
go func() {
dl, ok := ctx.Deadline()
if !ok {
return
}
time.Sleep(dl.Sub(time.Now()))
c.kill()
c.Stdout.Close()
c.Stdin.Close()
}()
waitErr := c.Cmd.Wait() waitErr := c.Cmd.Wait()
var wasUs bool = false var wasUs bool = false
var waitStatus syscall.WaitStatus var waitStatus syscall.WaitStatus
@ -133,10 +147,9 @@ func (c *IOCommand) Close() (err error) {
if c.Cmd.ProcessState == nil { if c.Cmd.ProcessState == nil {
// racy... // racy...
err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM) err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM)
if err != nil { ctx, cancel := context.WithTimeout(context.Background(), envconst.Duration("IOCOMMAND_TIMEOUT", 10*time.Second))
return defer cancel()
} return c.doWait(ctx)
return c.doWait()
} else { } else {
return c.ExitResult.Error return c.ExitResult.Error
} }