mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-22 00:13:52 +01:00
sshbytestream: IdentityFile and custom SSHCommand.
This commit is contained in:
parent
b9361f275a
commit
2e6dc26993
@ -14,6 +14,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const LOCAL_TRANSPORT_IDENTITY string = "local"
|
||||
|
||||
type Pool struct {
|
||||
Name string
|
||||
Transport Transport
|
||||
@ -23,14 +25,15 @@ type Transport interface {
|
||||
Connect() (rpc.RPCRequester, error)
|
||||
}
|
||||
type LocalTransport struct {
|
||||
Pool string
|
||||
Handler rpc.RPCHandler
|
||||
}
|
||||
type SSHTransport struct {
|
||||
ZreplIdentity string
|
||||
Host string
|
||||
User string
|
||||
Port uint16
|
||||
TransportOpenCommand []string
|
||||
IdentityFile string `mapstructure:"identity_file"`
|
||||
TransportOpenCommand []string `mapstructure:"transport_open_command"`
|
||||
SSHCommand string `mapstructure:"ssh_command"`
|
||||
Options []string
|
||||
}
|
||||
|
||||
@ -138,13 +141,6 @@ func parseTransport(it map[string]interface{}) (t Transport, err error) {
|
||||
return nil, err
|
||||
}
|
||||
return t, nil
|
||||
case "local":
|
||||
t := LocalTransport{}
|
||||
if err = mapstructure.Decode(val, &t); err != nil {
|
||||
err = errors.New(fmt.Sprintf("could not parse local transport: %s", err))
|
||||
return nil, err
|
||||
}
|
||||
return t, nil
|
||||
default:
|
||||
return nil, errors.New(fmt.Sprintf("unknown transport type '%s'\n", key))
|
||||
}
|
||||
@ -208,9 +204,17 @@ func parsePulls(v interface{}, pl poolLookup) (p []Pull, err error) {
|
||||
for i, e := range asList {
|
||||
|
||||
var fromPool *Pool
|
||||
|
||||
if e.From == LOCAL_TRANSPORT_IDENTITY {
|
||||
fromPool = &Pool{
|
||||
Name: "local",
|
||||
Transport: LocalTransport{},
|
||||
}
|
||||
} else {
|
||||
if fromPool, err = pl(e.From); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
pull := Pull{
|
||||
From: fromPool,
|
||||
}
|
||||
@ -319,7 +323,9 @@ func parseComboMapping(m map[string]string) (c zfs.ComboMapping, err error) {
|
||||
func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) {
|
||||
var stream io.ReadWriteCloser
|
||||
var rpcTransport sshbytestream.SSHTransport
|
||||
copier.Copy(rpcTransport, t)
|
||||
if err = copier.Copy(&rpcTransport, t); err != nil {
|
||||
return
|
||||
}
|
||||
if stream, err = sshbytestream.Outgoing(rpcTransport); err != nil {
|
||||
return
|
||||
}
|
||||
@ -327,6 +333,12 @@ func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) {
|
||||
}
|
||||
|
||||
func (t LocalTransport) Connect() (r rpc.RPCRequester, err error) {
|
||||
// TODO ugly hidden global variable reference
|
||||
return rpc.ConnectLocalRPC(handler), nil
|
||||
if t.Handler == nil {
|
||||
panic("local transport with uninitialized handler")
|
||||
}
|
||||
return rpc.ConnectLocalRPC(t.Handler), nil
|
||||
}
|
||||
|
||||
func (t *LocalTransport) SetHandler(handler rpc.RPCHandler) {
|
||||
t.Handler = handler
|
||||
}
|
||||
|
@ -14,7 +14,8 @@ type SSHTransport struct {
|
||||
Host string
|
||||
User string
|
||||
Port uint16
|
||||
TransportOpenCommand []string
|
||||
IdentityFile string
|
||||
SSHCommand string
|
||||
Options []string
|
||||
}
|
||||
|
||||
@ -44,21 +45,30 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
sshArgs := make([]string, 0, 2*len(remote.Options)+len(remote.TransportOpenCommand)+4)
|
||||
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
|
||||
sshArgs = append(sshArgs,
|
||||
"-p", fmt.Sprintf("%d", remote.Port),
|
||||
"-q",
|
||||
"-i", remote.IdentityFile,
|
||||
"-o", "BatchMode=yes",
|
||||
)
|
||||
for _, option := range remote.Options {
|
||||
sshArgs = append(sshArgs, "-o", option)
|
||||
}
|
||||
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
|
||||
sshArgs = append(sshArgs, remote.TransportOpenCommand...)
|
||||
|
||||
cmd := exec.CommandContext(ctx, SSHCommand, sshArgs...)
|
||||
var sshCommand = SSHCommand
|
||||
if len(remote.SSHCommand) > 0 {
|
||||
sshCommand = SSHCommand
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
|
||||
|
||||
// Clear environment of cmd
|
||||
cmd.Env = []string{}
|
||||
|
||||
var in io.WriteCloser
|
||||
var out io.ReadCloser
|
||||
var stderr io.Reader
|
||||
|
||||
if in, err = cmd.StdinPipe(); err != nil {
|
||||
return
|
||||
@ -67,6 +77,9 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
if out, err = cmd.StdoutPipe(); err != nil {
|
||||
return
|
||||
}
|
||||
if stderr, err = cmd.StderrPipe(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
f := ForkedSSHReadWriteCloser{
|
||||
RemoteStdin: in,
|
||||
@ -77,16 +90,19 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
|
||||
}
|
||||
|
||||
f.exitWaitGroup.Add(1)
|
||||
if err = cmd.Start(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer f.exitWaitGroup.Done()
|
||||
stderr, _ := cmd.StderrPipe()
|
||||
var b bytes.Buffer
|
||||
if err := cmd.Run(); err != nil {
|
||||
io.Copy(&b, stderr)
|
||||
fmt.Println(b.String())
|
||||
fmt.Printf("%v\n", cmd.ProcessState)
|
||||
//panic(err)
|
||||
if _, err := io.Copy(&b, stderr); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := cmd.Wait(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b)
|
||||
//panic(err) TODO
|
||||
}
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user