diff --git a/util/iocommand.go b/util/iocommand.go index 1113f74..fe5f9a3 100644 --- a/util/iocommand.go +++ b/util/iocommand.go @@ -4,15 +4,18 @@ import ( "bytes" "context" "fmt" + "github.com/zrepl/zrepl/util/envconst" "io" "os" "os/exec" "syscall" + "time" ) // An IOCommand exposes a forked process's std(in|out|err) through the io.ReadWriteCloser interface. type IOCommand struct { Cmd *exec.Cmd + kill context.CancelFunc Stdin io.WriteCloser Stdout io.ReadCloser StderrBuf *bytes.Buffer @@ -52,6 +55,7 @@ func NewIOCommand(ctx context.Context, command string, args []string, stderrBufS c = &IOCommand{} + ctx, c.kill = context.WithCancel(ctx) c.Cmd = exec.CommandContext(ctx, command, args...) if c.Stdout, err = c.Cmd.StdoutPipe(); err != nil { @@ -81,14 +85,24 @@ func (c *IOCommand) Start() (err error) { func (c *IOCommand) Read(buf []byte) (n int, err error) { n, err = c.Stdout.Read(buf) if err == io.EOF { - if waitErr := c.doWait(); waitErr != nil { + if waitErr := c.doWait(context.Background()); waitErr != nil { err = waitErr } } return } -func (c *IOCommand) doWait() (err error) { +func (c *IOCommand) doWait(ctx context.Context) (err error) { + go func() { + dl, ok := ctx.Deadline() + if !ok { + return + } + time.Sleep(dl.Sub(time.Now())) + c.kill() + c.Stdout.Close() + c.Stdin.Close() + }() waitErr := c.Cmd.Wait() var wasUs bool = false var waitStatus syscall.WaitStatus @@ -133,10 +147,9 @@ func (c *IOCommand) Close() (err error) { if c.Cmd.ProcessState == nil { // racy... err = syscall.Kill(c.Cmd.Process.Pid, syscall.SIGTERM) - if err != nil { - return - } - return c.doWait() + ctx, cancel := context.WithTimeout(context.Background(), envconst.Duration("IOCOMMAND_TIMEOUT", 10*time.Second)) + defer cancel() + return c.doWait(ctx) } else { return c.ExitResult.Error }