sshbytestream: IdentityFile and custom SSHCommand.

This commit is contained in:
Christian Schwarz 2017-04-30 16:11:33 +02:00
parent b9361f275a
commit 2e6dc26993
2 changed files with 57 additions and 29 deletions

View File

@ -14,6 +14,8 @@ import (
"strings"
)
const LOCAL_TRANSPORT_IDENTITY string = "local"
type Pool struct {
Name string
Transport Transport
@ -23,14 +25,15 @@ type Transport interface {
Connect() (rpc.RPCRequester, error)
}
type LocalTransport struct {
Pool string
Handler rpc.RPCHandler
}
type SSHTransport struct {
ZreplIdentity string
Host string
User string
Port uint16
TransportOpenCommand []string
IdentityFile string `mapstructure:"identity_file"`
TransportOpenCommand []string `mapstructure:"transport_open_command"`
SSHCommand string `mapstructure:"ssh_command"`
Options []string
}
@ -138,13 +141,6 @@ func parseTransport(it map[string]interface{}) (t Transport, err error) {
return nil, err
}
return t, nil
case "local":
t := LocalTransport{}
if err = mapstructure.Decode(val, &t); err != nil {
err = errors.New(fmt.Sprintf("could not parse local transport: %s", err))
return nil, err
}
return t, nil
default:
return nil, errors.New(fmt.Sprintf("unknown transport type '%s'\n", key))
}
@ -208,9 +204,17 @@ func parsePulls(v interface{}, pl poolLookup) (p []Pull, err error) {
for i, e := range asList {
var fromPool *Pool
if e.From == LOCAL_TRANSPORT_IDENTITY {
fromPool = &Pool{
Name: "local",
Transport: LocalTransport{},
}
} else {
if fromPool, err = pl(e.From); err != nil {
return
}
}
pull := Pull{
From: fromPool,
}
@ -319,7 +323,9 @@ func parseComboMapping(m map[string]string) (c zfs.ComboMapping, err error) {
func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) {
var stream io.ReadWriteCloser
var rpcTransport sshbytestream.SSHTransport
copier.Copy(rpcTransport, t)
if err = copier.Copy(&rpcTransport, t); err != nil {
return
}
if stream, err = sshbytestream.Outgoing(rpcTransport); err != nil {
return
}
@ -327,6 +333,12 @@ func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) {
}
func (t LocalTransport) Connect() (r rpc.RPCRequester, err error) {
// TODO ugly hidden global variable reference
return rpc.ConnectLocalRPC(handler), nil
if t.Handler == nil {
panic("local transport with uninitialized handler")
}
return rpc.ConnectLocalRPC(t.Handler), nil
}
func (t *LocalTransport) SetHandler(handler rpc.RPCHandler) {
t.Handler = handler
}

View File

@ -14,7 +14,8 @@ type SSHTransport struct {
Host string
User string
Port uint16
TransportOpenCommand []string
IdentityFile string
SSHCommand string
Options []string
}
@ -44,21 +45,30 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
ctx, cancel := context.WithCancel(context.Background())
sshArgs := make([]string, 0, 2*len(remote.Options)+len(remote.TransportOpenCommand)+4)
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
sshArgs = append(sshArgs,
"-p", fmt.Sprintf("%d", remote.Port),
"-q",
"-i", remote.IdentityFile,
"-o", "BatchMode=yes",
)
for _, option := range remote.Options {
sshArgs = append(sshArgs, "-o", option)
}
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
sshArgs = append(sshArgs, remote.TransportOpenCommand...)
cmd := exec.CommandContext(ctx, SSHCommand, sshArgs...)
var sshCommand = SSHCommand
if len(remote.SSHCommand) > 0 {
sshCommand = SSHCommand
}
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
// Clear environment of cmd
cmd.Env = []string{}
var in io.WriteCloser
var out io.ReadCloser
var stderr io.Reader
if in, err = cmd.StdinPipe(); err != nil {
return
@ -67,6 +77,9 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
if out, err = cmd.StdoutPipe(); err != nil {
return
}
if stderr, err = cmd.StderrPipe(); err != nil {
return
}
f := ForkedSSHReadWriteCloser{
RemoteStdin: in,
@ -77,16 +90,19 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
}
f.exitWaitGroup.Add(1)
if err = cmd.Start(); err != nil {
return
}
go func() {
defer f.exitWaitGroup.Done()
stderr, _ := cmd.StderrPipe()
var b bytes.Buffer
if err := cmd.Run(); err != nil {
io.Copy(&b, stderr)
fmt.Println(b.String())
fmt.Printf("%v\n", cmd.ProcessState)
//panic(err)
if _, err := io.Copy(&b, stderr); err != nil {
panic(err)
}
if err := cmd.Wait(); err != nil {
fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b)
//panic(err) TODO
}
}()