diff --git a/cmd/main.go b/cmd/main.go index 4591dd0..66d9473 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -29,9 +29,11 @@ func main() { defer func() { e := recover() - defaultLog.Printf("panic:\n%s\n\n", debug.Stack()) - defaultLog.Printf("error: %t %s", e, e) - os.Exit(1) + if e != nil { + defaultLog.Printf("panic:\n%s\n\n", debug.Stack()) + defaultLog.Printf("error: %t %s", e, e) + os.Exit(1) + } }() app := cli.NewApp() @@ -191,6 +193,32 @@ func doRun(c *cli.Context) error { return nil } +func closeRPCWithTimeout(log Logger, remote rpc.RPCRequester, timeout time.Duration, goodbye string) { + log.Printf("closing rpc connection") + + ch := make(chan error) + go func() { + ch <- remote.CloseRequest(rpc.CloseRequest{goodbye}) + }() + + var err error + select { + case <-time.After(timeout): + err = fmt.Errorf("timeout exceeded (%s)", timeout) + case closeRequestErr := <-ch: + err = closeRequestErr + } + + if err != nil { + log.Printf("error closing connection: %s", err) + err = remote.ForceClose() + if err != nil { + log.Printf("error force-closing connection: %s", err) + } + } + return +} + func doPull(pull Pull, c *cli.Context, log jobrun.Logger) (err error) { if lt, ok := pull.From.Transport.(LocalTransport); ok { @@ -208,7 +236,7 @@ func doPull(pull Pull, c *cli.Context, log jobrun.Logger) (err error) { return } - defer remote.Close() + defer closeRPCWithTimeout(log, remote, time.Second*10, "") fsr := rpc.FilesystemRequest{ Direction: rpc.DirectionPull,