diff --git a/Gopkg.lock b/Gopkg.lock index fabfd9d..078dc5a 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -79,6 +79,18 @@ revision = "792786c7400a136282c1664665ae0a8db921c6c2" version = "v1.0.0" +[[projects]] + branch = "master" + name = "github.com/problame/go-netssh" + packages = ["."] + revision = "ffa145d2506e222977205e7666a9722d6b9959ac" + +[[projects]] + branch = "master" + name = "github.com/problame/go-rwccmd" + packages = ["."] + revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7" + [[projects]] branch = "master" name = "github.com/spf13/cobra" @@ -106,6 +118,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "bdfcb09b88a5c4de0ea4a4fcd438b31f9feea455ebb17e4d3d4d620d704796e1" + inputs-digest = "3c3d8d4a2c6fdd6cff0826338b5e56ba0c95516215b40aea7e421409dc31b14f" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index c9a3abb..3b1d3ac 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -43,3 +43,7 @@ ignored = [ "github.com/inconshreveable/mousetrap" ] [[constraint]] name = "github.com/go-logfmt/logfmt" version = "*" + +[[constraint]] + name = "github.com/problame/go-rwccmd" + branch = "master" diff --git a/cmd/config_connect.go b/cmd/config_connect.go index 35be638..60bbc8d 100644 --- a/cmd/config_connect.go +++ b/cmd/config_connect.go @@ -4,10 +4,11 @@ import ( "fmt" "io" + "context" "github.com/jinzhu/copier" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" - "github.com/zrepl/zrepl/sshbytestream" + "github.com/problame/go-netssh" ) type SSHStdinserverConnecter struct { @@ -34,11 +35,12 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo } func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) { - var rpcTransport sshbytestream.SSHTransport - if err = copier.Copy(&rpcTransport, c); err != nil { + + var endpoint netssh.Endpoint + if err = copier.Copy(&endpoint, c); err != nil { return } - if rwc, err = sshbytestream.Outgoing(rpcTransport); err != nil { + if rwc, err = netssh.Dial(context.TODO(), endpoint); err != nil { err = errors.WithStack(err) return } diff --git a/cmd/config_serve_stdinserver.go b/cmd/config_serve_stdinserver.go index 0301ae3..896b4aa 100644 --- a/cmd/config_serve_stdinserver.go +++ b/cmd/config_serve_stdinserver.go @@ -1,18 +1,16 @@ package cmd import ( - "github.com/ftrvxmtrx/fd" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" + "github.com/problame/go-netssh" "io" - "net" - "os" "path" ) type StdinserverListenerFactory struct { ClientIdentity string `mapstructure:"client_identity"` - sockaddr *net.UnixAddr + sockpath string } func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) { @@ -27,78 +25,28 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface return } - f.sockaddr, err = stdinserverListenerSocket(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity) - if err != nil { - return - } + f.sockpath = path.Join(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity) return } -func stdinserverListenerSocket(sockdir, clientIdentity string) (addr *net.UnixAddr, err error) { - sockpath := path.Join(sockdir, clientIdentity) - addr, err = net.ResolveUnixAddr("unix", sockpath) - if err != nil { - return nil, errors.Wrap(err, "cannot resolve unix address") - } - return addr, nil -} - func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) { - ul, err := ListenUnixPrivate(f.sockaddr) + l, err := netssh.Listen(f.sockpath) if err != nil { - return nil, errors.Wrapf(err, "cannot listen on unix socket %s", f.sockaddr) + return nil, err } - - l := &StdinserverListener{ul} - - return l, nil + return StdinserverListener{l}, nil } type StdinserverListener struct { - l *net.UnixListener + l *netssh.Listener } -type fdRWC struct { - stdin, stdout *os.File - control *net.UnixConn +func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) { + return l.l.Accept() } -func (f fdRWC) Read(p []byte) (n int, err error) { - return f.stdin.Read(p) -} - -func (f fdRWC) Write(p []byte) (n int, err error) { - return f.stdout.Write(p) -} - -func (f fdRWC) Close() (err error) { - f.stdin.Close() - f.stdout.Close() - return f.control.Close() -} - -func (l *StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) { - c, err := l.l.Accept() - if err != nil { - err = errors.Wrap(err, "error accepting on unix listener") - return - } - - // Read the stdin and stdout of the stdinserver command - files, err := fd.Get(c.(*net.UnixConn), 2, []string{"stdin", "stdout"}) - if err != nil { - err = errors.Wrap(err, "error receiving fds from stdinserver command") - c.Close() - } - - rwc := fdRWC{files[0], files[1], c.(*net.UnixConn)} - - return rwc, nil - -} - -func (l *StdinserverListener) Close() (err error) { - return l.l.Close() // removes socket file automatically +func (l StdinserverListener) Close() (err error) { + return l.l.Close() } diff --git a/cmd/stdinserver.go b/cmd/stdinserver.go index 58dd70c..6d5d41e 100644 --- a/cmd/stdinserver.go +++ b/cmd/stdinserver.go @@ -1,14 +1,13 @@ package cmd import ( - "fmt" "os" - "github.com/ftrvxmtrx/fd" + "context" + "github.com/problame/go-netssh" "github.com/spf13/cobra" - "io" "log" - "net" + "path" ) var StdinserverCmd = &cobra.Command{ @@ -23,71 +22,34 @@ func init() { func cmdStdinServer(cmd *cobra.Command, args []string) { - log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime) + // NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong + defer os.Exit(1) - die := func() { - log.Printf("stdinserver exiting after fatal error") - os.Exit(1) - } + log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime) conf, err := ParseConfig(rootArgs.configFile) if err != nil { log.Printf("error parsing config: %s", err) - die() + return } if len(args) != 1 || args[0] == "" { - err = fmt.Errorf("must specify client_identity as positional argument") - die() + log.Print("must specify client_identity as positional argument") + return } + identity := args[0] + unixaddr := path.Join(conf.Global.Serve.Stdinserver.SockDir, identity) - unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity) - if err != nil { - log.Printf("%s", err) - os.Exit(1) + log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr) + + ctx := netssh.ContextWithLog(context.TODO(), log) + + err = netssh.Proxy(ctx, unixaddr) + if err == nil { + log.Print("proxying finished successfully, exiting with status 0") + os.Exit(0) } - - log.Printf("opening control connection to zrepld via %s", unixaddr) - conn, err := net.DialUnix("unix", nil, unixaddr) - if err != nil { - log.Printf("error connecting to zrepld: %s", err) - die() - } - - log.Printf("sending stdin and stdout fds to zrepld") - err = fd.Put(conn, os.Stdin, os.Stdout) - if err != nil { - log.Printf("error: %s", err) - die() - } - - log.Printf("waiting for zrepld to close control connection") - for { - - var buf [64]byte - n, err := conn.Read(buf[:]) - if err == nil && n != 0 { - log.Printf("protocol error: read expected to timeout or EOF returned bytes") - } - - if err == io.EOF { - log.Printf("zrepld closed control connection, terminating") - break - } - - neterr, ok := err.(net.Error) - if !ok { - log.Printf("received unexpected error type: %T %s", err, err) - die() - } - if !neterr.Timeout() { - log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr) - die() - } - // Read timed out, as expected - } - - return + log.Printf("error proxying: %s", err) } diff --git a/sshbytestream/ssh.go b/sshbytestream/ssh.go deleted file mode 100644 index ef155cc..0000000 --- a/sshbytestream/ssh.go +++ /dev/null @@ -1,123 +0,0 @@ -package sshbytestream - -import ( - "fmt" - "io" - "os" - "runtime" - "syscall" - - "github.com/zrepl/zrepl/util" -) - -type Error struct { - Stderr []byte - WaitErr error -} - -func (e Error) Error() string { - return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr) -} - -type SSHTransport struct { - Host string - User string - Port uint16 - IdentityFile string - SSHCommand string - Options []string -} - -var SSHCommand string = "ssh" -var SSHServerAliveInterval uint = 60 - -func Incoming() (wc io.ReadWriteCloser, err error) { - // derivce ReadWriteCloser from stdin & stdout - return IncomingReadWriteCloser{}, nil -} - -type IncomingReadWriteCloser struct{} - -func (f IncomingReadWriteCloser) Read(p []byte) (n int, err error) { - return os.Stdin.Read(p) -} - -func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) { - return os.Stdout.Write(p) -} - -func (f IncomingReadWriteCloser) Close() (err error) { - if err = os.Stdin.Close(); err != nil { - return - } - if err = os.Stdout.Close(); err != nil { - return - } - return -} - -type OutgoingSSHByteStream struct { - c *util.IOCommand -} - -func Outgoing(remote SSHTransport) (s OutgoingSSHByteStream, err error) { - - sshArgs := make([]string, 0, 2*len(remote.Options)+4) - sshArgs = append(sshArgs, - "-p", fmt.Sprintf("%d", remote.Port), - "-q", - "-i", remote.IdentityFile, - "-o", "BatchMode=yes", - "-o", fmt.Sprintf("ServerAliveInterval=%d", SSHServerAliveInterval), - ) - for _, option := range remote.Options { - sshArgs = append(sshArgs, "-o", option) - } - sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host)) - - var sshCommand = SSHCommand - if len(remote.SSHCommand) > 0 { - sshCommand = remote.SSHCommand - } - - if s.c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil { - return - } - - // Clear environment of cmd, ssh shall not rely on SSH_AUTH_SOCK, etc. - s.c.Cmd.Env = []string{} - - err = s.c.Start() - return -} - -func (s OutgoingSSHByteStream) Read(p []byte) (n int, err error) { - return s.c.Read(p) -} - -func (s OutgoingSSHByteStream) Write(p []byte) (n int, err error) { - return s.c.Write(p) -} - -func (s OutgoingSSHByteStream) Close() (err error) { - err = s.c.Close() - if err == nil || s.c.ExitResult == nil { - return - } - - // SSH catches SIGTERM and has different exit codes on different platforms - ws := s.c.ExitResult.WaitStatus - switch runtime.GOOS { - case "linux": - if ws.ExitStatus() == 128+int(syscall.SIGTERM) { // OpenSSH_7.5p1, OpenSSL 1.1.0f 25 May 2017 Arch Linux - err = nil - } - case "freebsd": // OpenSSH_7.2p2, OpenSSL 1.0.2k-freebsd 26 Jan 2017 - if ws.ExitStatus() == 255 { - err = nil - } - default: // TODO - } - - return -}