mirror of
https://github.com/zrepl/zrepl.git
synced 2024-12-28 09:58:52 +01:00
165 lines
3.4 KiB
Go
165 lines
3.4 KiB
Go
package sshbytestream
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"sync"
|
|
)
|
|
|
|
type Error struct {
|
|
Stderr []byte
|
|
WaitErr error
|
|
}
|
|
|
|
func (e Error) Error() string {
|
|
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
|
|
}
|
|
|
|
type SSHTransport struct {
|
|
Host string
|
|
User string
|
|
Port uint16
|
|
IdentityFile string
|
|
SSHCommand string
|
|
Options []string
|
|
}
|
|
|
|
var SSHCommand string = "ssh"
|
|
|
|
func Incoming() (wc io.ReadWriteCloser, err error) {
|
|
// derivce ReadWriteCloser from stdin & stdout
|
|
return IncomingReadWriteCloser{}, nil
|
|
}
|
|
|
|
type IncomingReadWriteCloser struct{}
|
|
|
|
func (f IncomingReadWriteCloser) Read(p []byte) (n int, err error) {
|
|
return os.Stdin.Read(p)
|
|
}
|
|
|
|
func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
|
|
return os.Stdout.Write(p)
|
|
}
|
|
|
|
func (f IncomingReadWriteCloser) Close() (err error) {
|
|
if err = os.Stdin.Close(); err != nil {
|
|
return
|
|
}
|
|
if err = os.Stdout.Close(); err != nil {
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
|
|
sshArgs = append(sshArgs,
|
|
"-p", fmt.Sprintf("%d", remote.Port),
|
|
"-q",
|
|
"-i", remote.IdentityFile,
|
|
"-o", "BatchMode=yes",
|
|
)
|
|
for _, option := range remote.Options {
|
|
sshArgs = append(sshArgs, "-o", option)
|
|
}
|
|
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
|
|
|
|
var sshCommand = SSHCommand
|
|
if len(remote.SSHCommand) > 0 {
|
|
sshCommand = SSHCommand
|
|
}
|
|
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
|
|
|
|
// Clear environment of cmd
|
|
cmd.Env = []string{}
|
|
|
|
var in io.WriteCloser
|
|
var out io.ReadCloser
|
|
|
|
if in, err = cmd.StdinPipe(); err != nil {
|
|
return
|
|
}
|
|
if out, err = cmd.StdoutPipe(); err != nil {
|
|
return
|
|
}
|
|
|
|
stderrBuf := bytes.NewBuffer(make([]byte, 0, 1024))
|
|
cmd.Stderr = stderrBuf
|
|
|
|
f := &ForkedSSHReadWriteCloser{
|
|
RemoteStdin: in,
|
|
RemoteStdout: out,
|
|
Cancel: cancel,
|
|
Command: cmd,
|
|
exitWaitGroup: &sync.WaitGroup{},
|
|
}
|
|
|
|
f.exitWaitGroup.Add(1)
|
|
if err = cmd.Start(); err != nil {
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
defer f.exitWaitGroup.Done()
|
|
|
|
// stderr output is only relevant for errors if the exit code is non-zero
|
|
if err := cmd.Wait(); err != nil {
|
|
f.SSHCommandError = Error{
|
|
Stderr: stderrBuf.Bytes(),
|
|
WaitErr: err,
|
|
}
|
|
} else {
|
|
f.SSHCommandError = io.EOF
|
|
}
|
|
}()
|
|
|
|
return f, nil
|
|
}
|
|
|
|
type ForkedSSHReadWriteCloser struct {
|
|
RemoteStdin io.Writer
|
|
RemoteStdout io.Reader
|
|
Command *exec.Cmd
|
|
Cancel context.CancelFunc
|
|
exitWaitGroup *sync.WaitGroup
|
|
SSHCommandError error
|
|
}
|
|
|
|
func (f *ForkedSSHReadWriteCloser) Read(p []byte) (n int, err error) {
|
|
if f.SSHCommandError != nil {
|
|
return 0, f.SSHCommandError
|
|
}
|
|
if n, err = f.RemoteStdout.Read(p); err == io.EOF {
|
|
// the ssh command has exited, but we need to wait for post-portem to finish
|
|
f.exitWaitGroup.Wait()
|
|
err = f.SSHCommandError
|
|
}
|
|
return
|
|
}
|
|
|
|
func (f *ForkedSSHReadWriteCloser) Write(p []byte) (n int, err error) {
|
|
if f.SSHCommandError != nil {
|
|
return 0, f.SSHCommandError
|
|
}
|
|
if n, err = f.RemoteStdin.Write(p); err == io.EOF {
|
|
// the ssh command has exited, but we need to wait for post-portem to finish
|
|
f.exitWaitGroup.Wait()
|
|
err = f.SSHCommandError
|
|
}
|
|
return
|
|
}
|
|
|
|
func (f *ForkedSSHReadWriteCloser) Close() (err error) {
|
|
// TODO should check SSHCommandError?
|
|
f.Cancel()
|
|
f.exitWaitGroup.Wait()
|
|
return f.SSHCommandError
|
|
}
|