ssh+stdinserver: dump sshbytestream for github.com/problame/go-netssh

Cleaner abstractions + underlying go-rwccmd package does proper handling
of asynchronous exits, etc.
This commit is contained in:
Christian Schwarz 2018-02-17 01:08:15 +01:00
parent fc1c46ffd7
commit ccd062e238
6 changed files with 54 additions and 249 deletions

14
Gopkg.lock generated
View File

@ -79,6 +79,18 @@
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
version = "v1.0.0"
[[projects]]
branch = "master"
name = "github.com/problame/go-netssh"
packages = ["."]
revision = "ffa145d2506e222977205e7666a9722d6b9959ac"
[[projects]]
branch = "master"
name = "github.com/problame/go-rwccmd"
packages = ["."]
revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7"
[[projects]]
branch = "master"
name = "github.com/spf13/cobra"
@ -106,6 +118,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "bdfcb09b88a5c4de0ea4a4fcd438b31f9feea455ebb17e4d3d4d620d704796e1"
inputs-digest = "3c3d8d4a2c6fdd6cff0826338b5e56ba0c95516215b40aea7e421409dc31b14f"
solver-name = "gps-cdcl"
solver-version = 1

View File

@ -43,3 +43,7 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
[[constraint]]
name = "github.com/go-logfmt/logfmt"
version = "*"
[[constraint]]
name = "github.com/problame/go-rwccmd"
branch = "master"

View File

@ -4,10 +4,11 @@ import (
"fmt"
"io"
"context"
"github.com/jinzhu/copier"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/zrepl/zrepl/sshbytestream"
"github.com/problame/go-netssh"
)
type SSHStdinserverConnecter struct {
@ -34,11 +35,12 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo
}
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) {
var rpcTransport sshbytestream.SSHTransport
if err = copier.Copy(&rpcTransport, c); err != nil {
var endpoint netssh.Endpoint
if err = copier.Copy(&endpoint, c); err != nil {
return
}
if rwc, err = sshbytestream.Outgoing(rpcTransport); err != nil {
if rwc, err = netssh.Dial(context.TODO(), endpoint); err != nil {
err = errors.WithStack(err)
return
}

View File

@ -1,18 +1,16 @@
package cmd
import (
"github.com/ftrvxmtrx/fd"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/problame/go-netssh"
"io"
"net"
"os"
"path"
)
type StdinserverListenerFactory struct {
ClientIdentity string `mapstructure:"client_identity"`
sockaddr *net.UnixAddr
sockpath string
}
func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
@ -27,78 +25,28 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface
return
}
f.sockaddr, err = stdinserverListenerSocket(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
if err != nil {
return
}
f.sockpath = path.Join(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
return
}
func stdinserverListenerSocket(sockdir, clientIdentity string) (addr *net.UnixAddr, err error) {
sockpath := path.Join(sockdir, clientIdentity)
addr, err = net.ResolveUnixAddr("unix", sockpath)
if err != nil {
return nil, errors.Wrap(err, "cannot resolve unix address")
}
return addr, nil
}
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
ul, err := ListenUnixPrivate(f.sockaddr)
l, err := netssh.Listen(f.sockpath)
if err != nil {
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", f.sockaddr)
return nil, err
}
l := &StdinserverListener{ul}
return l, nil
return StdinserverListener{l}, nil
}
type StdinserverListener struct {
l *net.UnixListener
l *netssh.Listener
}
type fdRWC struct {
stdin, stdout *os.File
control *net.UnixConn
func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
return l.l.Accept()
}
func (f fdRWC) Read(p []byte) (n int, err error) {
return f.stdin.Read(p)
}
func (f fdRWC) Write(p []byte) (n int, err error) {
return f.stdout.Write(p)
}
func (f fdRWC) Close() (err error) {
f.stdin.Close()
f.stdout.Close()
return f.control.Close()
}
func (l *StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
c, err := l.l.Accept()
if err != nil {
err = errors.Wrap(err, "error accepting on unix listener")
return
}
// Read the stdin and stdout of the stdinserver command
files, err := fd.Get(c.(*net.UnixConn), 2, []string{"stdin", "stdout"})
if err != nil {
err = errors.Wrap(err, "error receiving fds from stdinserver command")
c.Close()
}
rwc := fdRWC{files[0], files[1], c.(*net.UnixConn)}
return rwc, nil
}
func (l *StdinserverListener) Close() (err error) {
return l.l.Close() // removes socket file automatically
func (l StdinserverListener) Close() (err error) {
return l.l.Close()
}

View File

@ -1,14 +1,13 @@
package cmd
import (
"fmt"
"os"
"github.com/ftrvxmtrx/fd"
"context"
"github.com/problame/go-netssh"
"github.com/spf13/cobra"
"io"
"log"
"net"
"path"
)
var StdinserverCmd = &cobra.Command{
@ -23,71 +22,34 @@ func init() {
func cmdStdinServer(cmd *cobra.Command, args []string) {
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
defer os.Exit(1)
die := func() {
log.Printf("stdinserver exiting after fatal error")
os.Exit(1)
}
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
conf, err := ParseConfig(rootArgs.configFile)
if err != nil {
log.Printf("error parsing config: %s", err)
die()
return
}
if len(args) != 1 || args[0] == "" {
err = fmt.Errorf("must specify client_identity as positional argument")
die()
log.Print("must specify client_identity as positional argument")
return
}
identity := args[0]
unixaddr := path.Join(conf.Global.Serve.Stdinserver.SockDir, identity)
unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity)
if err != nil {
log.Printf("%s", err)
os.Exit(1)
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
ctx := netssh.ContextWithLog(context.TODO(), log)
err = netssh.Proxy(ctx, unixaddr)
if err == nil {
log.Print("proxying finished successfully, exiting with status 0")
os.Exit(0)
}
log.Printf("opening control connection to zrepld via %s", unixaddr)
conn, err := net.DialUnix("unix", nil, unixaddr)
if err != nil {
log.Printf("error connecting to zrepld: %s", err)
die()
}
log.Printf("sending stdin and stdout fds to zrepld")
err = fd.Put(conn, os.Stdin, os.Stdout)
if err != nil {
log.Printf("error: %s", err)
die()
}
log.Printf("waiting for zrepld to close control connection")
for {
var buf [64]byte
n, err := conn.Read(buf[:])
if err == nil && n != 0 {
log.Printf("protocol error: read expected to timeout or EOF returned bytes")
}
if err == io.EOF {
log.Printf("zrepld closed control connection, terminating")
break
}
neterr, ok := err.(net.Error)
if !ok {
log.Printf("received unexpected error type: %T %s", err, err)
die()
}
if !neterr.Timeout() {
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
die()
}
// Read timed out, as expected
}
return
log.Printf("error proxying: %s", err)
}

View File

@ -1,123 +0,0 @@
package sshbytestream
import (
"fmt"
"io"
"os"
"runtime"
"syscall"
"github.com/zrepl/zrepl/util"
)
type Error struct {
Stderr []byte
WaitErr error
}
func (e Error) Error() string {
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
}
type SSHTransport struct {
Host string
User string
Port uint16
IdentityFile string
SSHCommand string
Options []string
}
var SSHCommand string = "ssh"
var SSHServerAliveInterval uint = 60
func Incoming() (wc io.ReadWriteCloser, err error) {
// derivce ReadWriteCloser from stdin & stdout
return IncomingReadWriteCloser{}, nil
}
type IncomingReadWriteCloser struct{}
func (f IncomingReadWriteCloser) Read(p []byte) (n int, err error) {
return os.Stdin.Read(p)
}
func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
return os.Stdout.Write(p)
}
func (f IncomingReadWriteCloser) Close() (err error) {
if err = os.Stdin.Close(); err != nil {
return
}
if err = os.Stdout.Close(); err != nil {
return
}
return
}
type OutgoingSSHByteStream struct {
c *util.IOCommand
}
func Outgoing(remote SSHTransport) (s OutgoingSSHByteStream, err error) {
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
sshArgs = append(sshArgs,
"-p", fmt.Sprintf("%d", remote.Port),
"-q",
"-i", remote.IdentityFile,
"-o", "BatchMode=yes",
"-o", fmt.Sprintf("ServerAliveInterval=%d", SSHServerAliveInterval),
)
for _, option := range remote.Options {
sshArgs = append(sshArgs, "-o", option)
}
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
var sshCommand = SSHCommand
if len(remote.SSHCommand) > 0 {
sshCommand = remote.SSHCommand
}
if s.c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil {
return
}
// Clear environment of cmd, ssh shall not rely on SSH_AUTH_SOCK, etc.
s.c.Cmd.Env = []string{}
err = s.c.Start()
return
}
func (s OutgoingSSHByteStream) Read(p []byte) (n int, err error) {
return s.c.Read(p)
}
func (s OutgoingSSHByteStream) Write(p []byte) (n int, err error) {
return s.c.Write(p)
}
func (s OutgoingSSHByteStream) Close() (err error) {
err = s.c.Close()
if err == nil || s.c.ExitResult == nil {
return
}
// SSH catches SIGTERM and has different exit codes on different platforms
ws := s.c.ExitResult.WaitStatus
switch runtime.GOOS {
case "linux":
if ws.ExitStatus() == 128+int(syscall.SIGTERM) { // OpenSSH_7.5p1, OpenSSL 1.1.0f 25 May 2017 Arch Linux
err = nil
}
case "freebsd": // OpenSSH_7.2p2, OpenSSL 1.0.2k-freebsd 26 Jan 2017
if ws.ExitStatus() == 255 {
err = nil
}
default: // TODO
}
return
}