mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-29 03:45:27 +01:00
ssh+stdinserver: dump sshbytestream for github.com/problame/go-netssh
Cleaner abstractions + underlying go-rwccmd package does proper handling of asynchronous exits, etc.
This commit is contained in:
parent
fc1c46ffd7
commit
ccd062e238
14
Gopkg.lock
generated
14
Gopkg.lock
generated
@ -79,6 +79,18 @@
|
|||||||
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
|
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
|
||||||
version = "v1.0.0"
|
version = "v1.0.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "github.com/problame/go-netssh"
|
||||||
|
packages = ["."]
|
||||||
|
revision = "ffa145d2506e222977205e7666a9722d6b9959ac"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "github.com/problame/go-rwccmd"
|
||||||
|
packages = ["."]
|
||||||
|
revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
name = "github.com/spf13/cobra"
|
name = "github.com/spf13/cobra"
|
||||||
@ -106,6 +118,6 @@
|
|||||||
[solve-meta]
|
[solve-meta]
|
||||||
analyzer-name = "dep"
|
analyzer-name = "dep"
|
||||||
analyzer-version = 1
|
analyzer-version = 1
|
||||||
inputs-digest = "bdfcb09b88a5c4de0ea4a4fcd438b31f9feea455ebb17e4d3d4d620d704796e1"
|
inputs-digest = "3c3d8d4a2c6fdd6cff0826338b5e56ba0c95516215b40aea7e421409dc31b14f"
|
||||||
solver-name = "gps-cdcl"
|
solver-name = "gps-cdcl"
|
||||||
solver-version = 1
|
solver-version = 1
|
||||||
|
@ -43,3 +43,7 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
|
|||||||
[[constraint]]
|
[[constraint]]
|
||||||
name = "github.com/go-logfmt/logfmt"
|
name = "github.com/go-logfmt/logfmt"
|
||||||
version = "*"
|
version = "*"
|
||||||
|
|
||||||
|
[[constraint]]
|
||||||
|
name = "github.com/problame/go-rwccmd"
|
||||||
|
branch = "master"
|
||||||
|
@ -4,10 +4,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"context"
|
||||||
"github.com/jinzhu/copier"
|
"github.com/jinzhu/copier"
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/zrepl/zrepl/sshbytestream"
|
"github.com/problame/go-netssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SSHStdinserverConnecter struct {
|
type SSHStdinserverConnecter struct {
|
||||||
@ -34,11 +35,12 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) {
|
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) {
|
||||||
var rpcTransport sshbytestream.SSHTransport
|
|
||||||
if err = copier.Copy(&rpcTransport, c); err != nil {
|
var endpoint netssh.Endpoint
|
||||||
|
if err = copier.Copy(&endpoint, c); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if rwc, err = sshbytestream.Outgoing(rpcTransport); err != nil {
|
if rwc, err = netssh.Dial(context.TODO(), endpoint); err != nil {
|
||||||
err = errors.WithStack(err)
|
err = errors.WithStack(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,16 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/ftrvxmtrx/fd"
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/problame/go-netssh"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"path"
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StdinserverListenerFactory struct {
|
type StdinserverListenerFactory struct {
|
||||||
ClientIdentity string `mapstructure:"client_identity"`
|
ClientIdentity string `mapstructure:"client_identity"`
|
||||||
sockaddr *net.UnixAddr
|
sockpath string
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
||||||
@ -27,78 +25,28 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.sockaddr, err = stdinserverListenerSocket(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
|
f.sockpath = path.Join(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func stdinserverListenerSocket(sockdir, clientIdentity string) (addr *net.UnixAddr, err error) {
|
|
||||||
sockpath := path.Join(sockdir, clientIdentity)
|
|
||||||
addr, err = net.ResolveUnixAddr("unix", sockpath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "cannot resolve unix address")
|
|
||||||
}
|
|
||||||
return addr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
|
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
|
||||||
|
|
||||||
ul, err := ListenUnixPrivate(f.sockaddr)
|
l, err := netssh.Listen(f.sockpath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", f.sockaddr)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return StdinserverListener{l}, nil
|
||||||
l := &StdinserverListener{ul}
|
|
||||||
|
|
||||||
return l, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StdinserverListener struct {
|
type StdinserverListener struct {
|
||||||
l *net.UnixListener
|
l *netssh.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
type fdRWC struct {
|
func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
|
||||||
stdin, stdout *os.File
|
return l.l.Accept()
|
||||||
control *net.UnixConn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f fdRWC) Read(p []byte) (n int, err error) {
|
func (l StdinserverListener) Close() (err error) {
|
||||||
return f.stdin.Read(p)
|
return l.l.Close()
|
||||||
}
|
|
||||||
|
|
||||||
func (f fdRWC) Write(p []byte) (n int, err error) {
|
|
||||||
return f.stdout.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f fdRWC) Close() (err error) {
|
|
||||||
f.stdin.Close()
|
|
||||||
f.stdout.Close()
|
|
||||||
return f.control.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
|
|
||||||
c, err := l.l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
err = errors.Wrap(err, "error accepting on unix listener")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the stdin and stdout of the stdinserver command
|
|
||||||
files, err := fd.Get(c.(*net.UnixConn), 2, []string{"stdin", "stdout"})
|
|
||||||
if err != nil {
|
|
||||||
err = errors.Wrap(err, "error receiving fds from stdinserver command")
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
rwc := fdRWC{files[0], files[1], c.(*net.UnixConn)}
|
|
||||||
|
|
||||||
return rwc, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *StdinserverListener) Close() (err error) {
|
|
||||||
return l.l.Close() // removes socket file automatically
|
|
||||||
}
|
}
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"github.com/ftrvxmtrx/fd"
|
"context"
|
||||||
|
"github.com/problame/go-netssh"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
var StdinserverCmd = &cobra.Command{
|
var StdinserverCmd = &cobra.Command{
|
||||||
@ -23,71 +22,34 @@ func init() {
|
|||||||
|
|
||||||
func cmdStdinServer(cmd *cobra.Command, args []string) {
|
func cmdStdinServer(cmd *cobra.Command, args []string) {
|
||||||
|
|
||||||
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
|
||||||
|
defer os.Exit(1)
|
||||||
|
|
||||||
die := func() {
|
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||||
log.Printf("stdinserver exiting after fatal error")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
conf, err := ParseConfig(rootArgs.configFile)
|
conf, err := ParseConfig(rootArgs.configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error parsing config: %s", err)
|
log.Printf("error parsing config: %s", err)
|
||||||
die()
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(args) != 1 || args[0] == "" {
|
if len(args) != 1 || args[0] == "" {
|
||||||
err = fmt.Errorf("must specify client_identity as positional argument")
|
log.Print("must specify client_identity as positional argument")
|
||||||
die()
|
|
||||||
}
|
|
||||||
identity := args[0]
|
|
||||||
|
|
||||||
unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("%s", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("opening control connection to zrepld via %s", unixaddr)
|
|
||||||
conn, err := net.DialUnix("unix", nil, unixaddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("error connecting to zrepld: %s", err)
|
|
||||||
die()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("sending stdin and stdout fds to zrepld")
|
|
||||||
err = fd.Put(conn, os.Stdin, os.Stdout)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("error: %s", err)
|
|
||||||
die()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("waiting for zrepld to close control connection")
|
|
||||||
for {
|
|
||||||
|
|
||||||
var buf [64]byte
|
|
||||||
n, err := conn.Read(buf[:])
|
|
||||||
if err == nil && n != 0 {
|
|
||||||
log.Printf("protocol error: read expected to timeout or EOF returned bytes")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == io.EOF {
|
|
||||||
log.Printf("zrepld closed control connection, terminating")
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
neterr, ok := err.(net.Error)
|
|
||||||
if !ok {
|
|
||||||
log.Printf("received unexpected error type: %T %s", err, err)
|
|
||||||
die()
|
|
||||||
}
|
|
||||||
if !neterr.Timeout() {
|
|
||||||
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
|
|
||||||
die()
|
|
||||||
}
|
|
||||||
// Read timed out, as expected
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
identity := args[0]
|
||||||
|
unixaddr := path.Join(conf.Global.Serve.Stdinserver.SockDir, identity)
|
||||||
|
|
||||||
|
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
|
||||||
|
|
||||||
|
ctx := netssh.ContextWithLog(context.TODO(), log)
|
||||||
|
|
||||||
|
err = netssh.Proxy(ctx, unixaddr)
|
||||||
|
if err == nil {
|
||||||
|
log.Print("proxying finished successfully, exiting with status 0")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
log.Printf("error proxying: %s", err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,123 +0,0 @@
|
|||||||
package sshbytestream
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/zrepl/zrepl/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Error struct {
|
|
||||||
Stderr []byte
|
|
||||||
WaitErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e Error) Error() string {
|
|
||||||
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
type SSHTransport struct {
|
|
||||||
Host string
|
|
||||||
User string
|
|
||||||
Port uint16
|
|
||||||
IdentityFile string
|
|
||||||
SSHCommand string
|
|
||||||
Options []string
|
|
||||||
}
|
|
||||||
|
|
||||||
var SSHCommand string = "ssh"
|
|
||||||
var SSHServerAliveInterval uint = 60
|
|
||||||
|
|
||||||
func Incoming() (wc io.ReadWriteCloser, err error) {
|
|
||||||
// derivce ReadWriteCloser from stdin & stdout
|
|
||||||
return IncomingReadWriteCloser{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type IncomingReadWriteCloser struct{}
|
|
||||||
|
|
||||||
func (f IncomingReadWriteCloser) Read(p []byte) (n int, err error) {
|
|
||||||
return os.Stdin.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
|
|
||||||
return os.Stdout.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f IncomingReadWriteCloser) Close() (err error) {
|
|
||||||
if err = os.Stdin.Close(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = os.Stdout.Close(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type OutgoingSSHByteStream struct {
|
|
||||||
c *util.IOCommand
|
|
||||||
}
|
|
||||||
|
|
||||||
func Outgoing(remote SSHTransport) (s OutgoingSSHByteStream, err error) {
|
|
||||||
|
|
||||||
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
|
|
||||||
sshArgs = append(sshArgs,
|
|
||||||
"-p", fmt.Sprintf("%d", remote.Port),
|
|
||||||
"-q",
|
|
||||||
"-i", remote.IdentityFile,
|
|
||||||
"-o", "BatchMode=yes",
|
|
||||||
"-o", fmt.Sprintf("ServerAliveInterval=%d", SSHServerAliveInterval),
|
|
||||||
)
|
|
||||||
for _, option := range remote.Options {
|
|
||||||
sshArgs = append(sshArgs, "-o", option)
|
|
||||||
}
|
|
||||||
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
|
|
||||||
|
|
||||||
var sshCommand = SSHCommand
|
|
||||||
if len(remote.SSHCommand) > 0 {
|
|
||||||
sshCommand = remote.SSHCommand
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear environment of cmd, ssh shall not rely on SSH_AUTH_SOCK, etc.
|
|
||||||
s.c.Cmd.Env = []string{}
|
|
||||||
|
|
||||||
err = s.c.Start()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s OutgoingSSHByteStream) Read(p []byte) (n int, err error) {
|
|
||||||
return s.c.Read(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s OutgoingSSHByteStream) Write(p []byte) (n int, err error) {
|
|
||||||
return s.c.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s OutgoingSSHByteStream) Close() (err error) {
|
|
||||||
err = s.c.Close()
|
|
||||||
if err == nil || s.c.ExitResult == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSH catches SIGTERM and has different exit codes on different platforms
|
|
||||||
ws := s.c.ExitResult.WaitStatus
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "linux":
|
|
||||||
if ws.ExitStatus() == 128+int(syscall.SIGTERM) { // OpenSSH_7.5p1, OpenSSL 1.1.0f 25 May 2017 Arch Linux
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
case "freebsd": // OpenSSH_7.2p2, OpenSSL 1.0.2k-freebsd 26 Jan 2017
|
|
||||||
if ws.ExitStatus() == 255 {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
default: // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user