diff --git a/cmd/config.go b/cmd/config.go index 1720eda..163cf1f 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -14,6 +14,8 @@ import ( "strings" ) +const LOCAL_TRANSPORT_IDENTITY string = "local" + type Pool struct { Name string Transport Transport @@ -23,14 +25,15 @@ type Transport interface { Connect() (rpc.RPCRequester, error) } type LocalTransport struct { - Pool string + Handler rpc.RPCHandler } type SSHTransport struct { - ZreplIdentity string Host string User string Port uint16 - TransportOpenCommand []string + IdentityFile string `mapstructure:"identity_file"` + TransportOpenCommand []string `mapstructure:"transport_open_command"` + SSHCommand string `mapstructure:"ssh_command"` Options []string } @@ -138,13 +141,6 @@ func parseTransport(it map[string]interface{}) (t Transport, err error) { return nil, err } return t, nil - case "local": - t := LocalTransport{} - if err = mapstructure.Decode(val, &t); err != nil { - err = errors.New(fmt.Sprintf("could not parse local transport: %s", err)) - return nil, err - } - return t, nil default: return nil, errors.New(fmt.Sprintf("unknown transport type '%s'\n", key)) } @@ -208,8 +204,16 @@ func parsePulls(v interface{}, pl poolLookup) (p []Pull, err error) { for i, e := range asList { var fromPool *Pool - if fromPool, err = pl(e.From); err != nil { - return + + if e.From == LOCAL_TRANSPORT_IDENTITY { + fromPool = &Pool{ + Name: "local", + Transport: LocalTransport{}, + } + } else { + if fromPool, err = pl(e.From); err != nil { + return + } } pull := Pull{ From: fromPool, @@ -319,7 +323,9 @@ func parseComboMapping(m map[string]string) (c zfs.ComboMapping, err error) { func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) { var stream io.ReadWriteCloser var rpcTransport sshbytestream.SSHTransport - copier.Copy(rpcTransport, t) + if err = copier.Copy(&rpcTransport, t); err != nil { + return + } if stream, err = sshbytestream.Outgoing(rpcTransport); err != nil { return } @@ -327,6 +333,12 @@ func (t SSHTransport) Connect() (r rpc.RPCRequester, err error) { } func (t LocalTransport) Connect() (r rpc.RPCRequester, err error) { - // TODO ugly hidden global variable reference - return rpc.ConnectLocalRPC(handler), nil + if t.Handler == nil { + panic("local transport with uninitialized handler") + } + return rpc.ConnectLocalRPC(t.Handler), nil +} + +func (t *LocalTransport) SetHandler(handler rpc.RPCHandler) { + t.Handler = handler } diff --git a/sshbytestream/ssh.go b/sshbytestream/ssh.go index bf9b239..3e5b328 100644 --- a/sshbytestream/ssh.go +++ b/sshbytestream/ssh.go @@ -11,11 +11,12 @@ import ( ) type SSHTransport struct { - Host string - User string - Port uint16 - TransportOpenCommand []string - Options []string + Host string + User string + Port uint16 + IdentityFile string + SSHCommand string + Options []string } var SSHCommand string = "ssh" @@ -44,21 +45,30 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { ctx, cancel := context.WithCancel(context.Background()) - sshArgs := make([]string, 0, 2*len(remote.Options)+len(remote.TransportOpenCommand)+4) + sshArgs := make([]string, 0, 2*len(remote.Options)+4) sshArgs = append(sshArgs, "-p", fmt.Sprintf("%d", remote.Port), + "-q", + "-i", remote.IdentityFile, "-o", "BatchMode=yes", ) for _, option := range remote.Options { sshArgs = append(sshArgs, "-o", option) } sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host)) - sshArgs = append(sshArgs, remote.TransportOpenCommand...) - cmd := exec.CommandContext(ctx, SSHCommand, sshArgs...) + var sshCommand = SSHCommand + if len(remote.SSHCommand) > 0 { + sshCommand = SSHCommand + } + cmd := exec.CommandContext(ctx, sshCommand, sshArgs...) + + // Clear environment of cmd + cmd.Env = []string{} var in io.WriteCloser var out io.ReadCloser + var stderr io.Reader if in, err = cmd.StdinPipe(); err != nil { return @@ -67,6 +77,9 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { if out, err = cmd.StdoutPipe(); err != nil { return } + if stderr, err = cmd.StderrPipe(); err != nil { + return + } f := ForkedSSHReadWriteCloser{ RemoteStdin: in, @@ -77,16 +90,19 @@ func Outgoing(remote SSHTransport) (conn io.ReadWriteCloser, err error) { } f.exitWaitGroup.Add(1) + if err = cmd.Start(); err != nil { + return + } go func() { defer f.exitWaitGroup.Done() - stderr, _ := cmd.StderrPipe() var b bytes.Buffer - if err := cmd.Run(); err != nil { - io.Copy(&b, stderr) - fmt.Println(b.String()) - fmt.Printf("%v\n", cmd.ProcessState) - //panic(err) + if _, err := io.Copy(&b, stderr); err != nil { + panic(err) + } + if err := cmd.Wait(); err != nil { + fmt.Fprintf(os.Stderr, "ssh command exited with error: %v. Stderr:\n%s\n", cmd.ProcessState, b) + //panic(err) TODO } }()