sshbytestream: IdentityFile and custom SSHCommand.

This commit is contained in:
Christian Schwarz 2017-04-30 16:11:33 +02:00
parent b9361f275a
commit 2e6dc26993
2 changed files with 57 additions and 29 deletions

View File

@ -14,6 +14,8 @@ import (
"strings" "strings"
) )
const LOCAL_TRANSPORT_IDENTITY string = "local"
type Pool struct { type Pool struct {
Name string Name string
Transport Transport Transport Transport
@ -23,14 +25,15 @@ type Transport interface {
Connect() (rpc.RPCRequester, error) Connect() (rpc.RPCRequester, error)
} }
type LocalTransport struct { type LocalTransport struct {
Pool string Handler rpc.RPCHandler
} }
type SSHTransport struct { type SSHTransport struct {
ZreplIdentity string
Host string Host string
User string User string
Port uint16 Port uint16
TransportOpenCommand []string IdentityFile string `mapstructure:"identity_file"`
TransportOpenCommand []string `mapstructure:"transport_open_command"`
SSHCommand string `mapstructure:"ssh_command"`
Options []string Options []string
} }
@ -138,13 +141,6 @@ func parseTransport(it map[string]interface{}) (t Transport, err error) {
return nil, err return nil, err
} }
return t, nil return t, nil
case "local":
t := LocalTransport{}
if err = mapstructure.Decode(val, &t); err != nil {
err = errors.New(fmt.Sprintf("could not parse local transport: %s", err))
return nil, err
}
return t, nil
default: default:
return nil, errors.New(fmt.Sprintf("unknown transport type '%s'\n", key)) return nil, errors.New(fmt.Sprintf("unknown transport type '%s'\n", key))
} }
@ -208,8 +204,16 @@ func parsePulls(v interface{}, pl poolLookup) (p []Pull, err error) {
for i, e := range asList { for i, e := range asList {
var fromPool *Pool var fromPool *Pool
if fromPool, err = pl(e.From); err != nil {
return if e.From == LOCAL_TRANSPORT_IDENTITY {
fromPool = &Pool{
Name: "local",
Transport: LocalTransport{},
}
} else {
if fromPool, err = pl(e.From); err != nil {
return
}
} }
pull := Pull{ pull := Pull{
From: fromPool, From: fromPool,
@ -319,7 +323,9 @@ func parseComboMapping(m map[string]string) (c zfs.ComboMapping, err error) {
func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) { func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) {
var stream io.ReadWriteCloser var stream io.ReadWriteCloser
var rpcTransport sshbytestream.SSHTransport var rpcTransport sshbytestream.SSHTransport
copier.Copy(rpcTransport, t) if err = copier.Copy(&rpcTransport, t); err != nil {
return
}
if stream, err = sshbytestream.Outgoing(rpcTransport); err != nil { if stream, err = sshbytestream.Outgoing(rpcTransport); err != nil {
return return
} }
@ -327,6 +333,12 @@ func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) {
} }
func (t LocalTransport) Connect() (r rpc.RPCRequester, err error) { func (t LocalTransport) Connect() (r rpc.RPCRequester, err error) {
// TODO ugly hidden global variable reference if t.Handler == nil {
return rpc.ConnectLocalRPC(handler), nil panic("local transport with uninitialized handler")
}
return rpc.ConnectLocalRPC(t.Handler), nil
}
func (t *LocalTransport) SetHandler(handler rpc.RPCHandler) {
t.Handler = handler
} }

View File

@ -11,11 +11,12 @@ import (
) )
type SSHTransport struct { type SSHTransport struct {
Host string Host string
User string User string
Port uint16 Port uint16
TransportOpenCommand []string IdentityFile string
Options []string SSHCommand string
Options []string
} }
var SSHCommand string = "ssh" var SSHCommand string = "ssh"
@ -44,21 +45,30 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
sshArgs := make([]string, 0, 2*len(remote.Options)+len(remote.TransportOpenCommand)+4) sshArgs := make([]string, 0, 2*len(remote.Options)+4)
sshArgs = append(sshArgs, sshArgs = append(sshArgs,
"-p", fmt.Sprintf("%d", remote.Port), "-p", fmt.Sprintf("%d", remote.Port),
"-q",
"-i", remote.IdentityFile,
"-o", "BatchMode=yes", "-o", "BatchMode=yes",
) )
for _, option := range remote.Options { for _, option := range remote.Options {
sshArgs = append(sshArgs, "-o", option) sshArgs = append(sshArgs, "-o", option)
} }
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host)) sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
sshArgs = append(sshArgs, remote.TransportOpenCommand...)
cmd := exec.CommandContext(ctx, SSHCommand, sshArgs...) var sshCommand = SSHCommand
if len(remote.SSHCommand) > 0 {
sshCommand = SSHCommand
}
cmd := exec.CommandContext(ctx, sshCommand, sshArgs...)
// Clear environment of cmd
cmd.Env = []string{}
var in io.WriteCloser var in io.WriteCloser
var out io.ReadCloser var out io.ReadCloser
var stderr io.Reader
if in, err = cmd.StdinPipe(); err != nil { if in, err = cmd.StdinPipe(); err != nil {
return return
@ -67,6 +77,9 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
if out, err = cmd.StdoutPipe(); err != nil { if out, err = cmd.StdoutPipe(); err != nil {
return return
} }
if stderr, err = cmd.StderrPipe(); err != nil {
return
}
f := ForkedSSHReadWriteCloser{ f := ForkedSSHReadWriteCloser{
RemoteStdin: in, RemoteStdin: in,
@ -77,16 +90,19 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) {
} }
f.exitWaitGroup.Add(1) f.exitWaitGroup.Add(1)
if err = cmd.Start(); err != nil {
return
}
go func() { go func() {
defer f.exitWaitGroup.Done() defer f.exitWaitGroup.Done()
stderr, _ := cmd.StderrPipe()
var b bytes.Buffer var b bytes.Buffer
if err := cmd.Run(); err != nil { if _, err := io.Copy(&b, stderr); err != nil {
io.Copy(&b, stderr) panic(err)
fmt.Println(b.String()) }
fmt.Printf("%v\n", cmd.ProcessState) if err := cmd.Wait(); err != nil {
//panic(err) fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b)
//panic(err) TODO
} }
}() }()