zrepl/sshbytestream/ssh.go
2017-05-06 23:44:59 +02:00

165 lines
3.4 KiB
Go

package sshbytestream
import (
"bytes"
"context"
"fmt"
"io"
"os"
"os/exec"
"sync"
)
type Error struct {
Stderr []byte
WaitErr error
}
func (e Error) Error() string {
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
}
type SSHTransport struct {
Host string
User string
Port uint16
IdentityFile string
SSHCommand string
Options []string
}
var SSHCommand string = "ssh"
func Incoming() (wc io.ReadWriteCloser, err error) {
// derivce ReadWriteCloser from stdin & stdout
return IncomingReadWriteCloser{}, nil
}
type IncomingReadWriteCloser struct{}
func (f IncomingReadWriteCloser) Read(p []byte) (n int, err error) {
return os.Stdin.Read(p)
}
func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
return os.Stdout.Write(p)
}
func (f IncomingReadWriteCloser) Close() (err error) {
if err = os.Stdin.Close(); err != nil {
return
}
if err = os.Stdout.Close(); err != nil {
return
}
return
}
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
ctx, cancel := context.WithCancel(context.Background())
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
sshArgs = append(sshArgs,
"-p", fmt.Sprintf("%d", remote.Port),
"-q",
"-i", remote.IdentityFile,
"-o", "BatchMode=yes",
)
for _, option := range remote.Options {
sshArgs = append(sshArgs, "-o", option)
}
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
var sshCommand = SSHCommand
if len(remote.SSHCommand) > 0 {
sshCommand = SSHCommand
}
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
// Clear environment of cmd
cmd.Env = []string{}
var in io.WriteCloser
var out io.ReadCloser
if in, err = cmd.StdinPipe(); err != nil {
return
}
if out, err = cmd.StdoutPipe(); err != nil {
return
}
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
cmd.Stderr = stderrBuf
f := &ForkedSSHReadWriteCloser{
RemoteStdin: in,
RemoteStdout: out,
Cancel: cancel,
Command: cmd,
exitWaitGroup: &sync.WaitGroup{},
}
f.exitWaitGroup.Add(1)
if err = cmd.Start(); err != nil {
return
}
go func() {
defer f.exitWaitGroup.Done()
// stderr output is only relevant for errors if the exit code is non-zero
if err := cmd.Wait(); err != nil {
f.SSHCommandError = Error{
Stderr: stderrBuf.Bytes(),
WaitErr: err,
}
} else {
f.SSHCommandError = io.EOF
}
}()
return f, nil
}
type ForkedSSHReadWriteCloser struct {
RemoteStdin io.Writer
RemoteStdout io.Reader
Command *exec.Cmd
Cancel context.CancelFunc
exitWaitGroup *sync.WaitGroup
SSHCommandError error
}
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
err = f.SSHCommandError
}
return
}
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
if f.SSHCommandError != nil {
return 0, f.SSHCommandError
}
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
// the ssh command has exited, but we need to wait for post-portem to finish
f.exitWaitGroup.Wait()
err = f.SSHCommandError
}
return
}
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
// TODO should check SSHCommandError?
f.Cancel()
f.exitWaitGroup.Wait()
return f.SSHCommandError
}