mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-22 08:23:50 +01:00
ssh+stdinserver: dump sshbytestream for github.com/problame/go-netssh
Cleaner abstractions + underlying go-rwccmd package does proper handling of asynchronous exits, etc.
This commit is contained in:
parent
fc1c46ffd7
commit
ccd062e238
14
Gopkg.lock
generated
14
Gopkg.lock
generated
@ -79,6 +79,18 @@
|
||||
revision = "792786c7400a136282c1664665ae0a8db921c6c2"
|
||||
version = "v1.0.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/problame/go-netssh"
|
||||
packages = ["."]
|
||||
revision = "ffa145d2506e222977205e7666a9722d6b9959ac"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/problame/go-rwccmd"
|
||||
packages = ["."]
|
||||
revision = "391d2c78c8404a9683d79f75dd24ab53040f89f7"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/spf13/cobra"
|
||||
@ -106,6 +118,6 @@
|
||||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "bdfcb09b88a5c4de0ea4a4fcd438b31f9feea455ebb17e4d3d4d620d704796e1"
|
||||
inputs-digest = "3c3d8d4a2c6fdd6cff0826338b5e56ba0c95516215b40aea7e421409dc31b14f"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
@ -43,3 +43,7 @@ ignored = [ "github.com/inconshreveable/mousetrap" ]
|
||||
[[constraint]]
|
||||
name = "github.com/go-logfmt/logfmt"
|
||||
version = "*"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/problame/go-rwccmd"
|
||||
branch = "master"
|
||||
|
@ -4,10 +4,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"context"
|
||||
"github.com/jinzhu/copier"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/sshbytestream"
|
||||
"github.com/problame/go-netssh"
|
||||
)
|
||||
|
||||
type SSHStdinserverConnecter struct {
|
||||
@ -34,11 +35,12 @@ func parseSSHStdinserverConnecter(i map[string]interface{}) (c *SSHStdinserverCo
|
||||
}
|
||||
|
||||
func (c *SSHStdinserverConnecter) Connect() (rwc io.ReadWriteCloser, err error) {
|
||||
var rpcTransport sshbytestream.SSHTransport
|
||||
if err = copier.Copy(&rpcTransport, c); err != nil {
|
||||
|
||||
var endpoint netssh.Endpoint
|
||||
if err = copier.Copy(&endpoint, c); err != nil {
|
||||
return
|
||||
}
|
||||
if rwc, err = sshbytestream.Outgoing(rpcTransport); err != nil {
|
||||
if rwc, err = netssh.Dial(context.TODO(), endpoint); err != nil {
|
||||
err = errors.WithStack(err)
|
||||
return
|
||||
}
|
||||
|
@ -1,18 +1,16 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/ftrvxmtrx/fd"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/problame/go-netssh"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
type StdinserverListenerFactory struct {
|
||||
ClientIdentity string `mapstructure:"client_identity"`
|
||||
sockaddr *net.UnixAddr
|
||||
sockpath string
|
||||
}
|
||||
|
||||
func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
||||
@ -27,78 +25,28 @@ func parseStdinserverListenerFactory(c JobParsingContext, i map[string]interface
|
||||
return
|
||||
}
|
||||
|
||||
f.sockaddr, err = stdinserverListenerSocket(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
f.sockpath = path.Join(c.Global.Serve.Stdinserver.SockDir, f.ClientIdentity)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func stdinserverListenerSocket(sockdir, clientIdentity string) (addr *net.UnixAddr, err error) {
|
||||
sockpath := path.Join(sockdir, clientIdentity)
|
||||
addr, err = net.ResolveUnixAddr("unix", sockpath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot resolve unix address")
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
|
||||
|
||||
ul, err := ListenUnixPrivate(f.sockaddr)
|
||||
l, err := netssh.Listen(f.sockpath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", f.sockaddr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := &StdinserverListener{ul}
|
||||
|
||||
return l, nil
|
||||
return StdinserverListener{l}, nil
|
||||
}
|
||||
|
||||
type StdinserverListener struct {
|
||||
l *net.UnixListener
|
||||
l *netssh.Listener
|
||||
}
|
||||
|
||||
type fdRWC struct {
|
||||
stdin, stdout *os.File
|
||||
control *net.UnixConn
|
||||
func (l StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
|
||||
return l.l.Accept()
|
||||
}
|
||||
|
||||
func (f fdRWC) Read(p []byte) (n int, err error) {
|
||||
return f.stdin.Read(p)
|
||||
}
|
||||
|
||||
func (f fdRWC) Write(p []byte) (n int, err error) {
|
||||
return f.stdout.Write(p)
|
||||
}
|
||||
|
||||
func (f fdRWC) Close() (err error) {
|
||||
f.stdin.Close()
|
||||
f.stdout.Close()
|
||||
return f.control.Close()
|
||||
}
|
||||
|
||||
func (l *StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "error accepting on unix listener")
|
||||
return
|
||||
}
|
||||
|
||||
// Read the stdin and stdout of the stdinserver command
|
||||
files, err := fd.Get(c.(*net.UnixConn), 2, []string{"stdin", "stdout"})
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "error receiving fds from stdinserver command")
|
||||
c.Close()
|
||||
}
|
||||
|
||||
rwc := fdRWC{files[0], files[1], c.(*net.UnixConn)}
|
||||
|
||||
return rwc, nil
|
||||
|
||||
}
|
||||
|
||||
func (l *StdinserverListener) Close() (err error) {
|
||||
return l.l.Close() // removes socket file automatically
|
||||
func (l StdinserverListener) Close() (err error) {
|
||||
return l.l.Close()
|
||||
}
|
||||
|
@ -1,14 +1,13 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ftrvxmtrx/fd"
|
||||
"context"
|
||||
"github.com/problame/go-netssh"
|
||||
"github.com/spf13/cobra"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"path"
|
||||
)
|
||||
|
||||
var StdinserverCmd = &cobra.Command{
|
||||
@ -23,71 +22,34 @@ func init() {
|
||||
|
||||
func cmdStdinServer(cmd *cobra.Command, args []string) {
|
||||
|
||||
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
|
||||
defer os.Exit(1)
|
||||
|
||||
die := func() {
|
||||
log.Printf("stdinserver exiting after fatal error")
|
||||
os.Exit(1)
|
||||
}
|
||||
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||
|
||||
conf, err := ParseConfig(rootArgs.configFile)
|
||||
if err != nil {
|
||||
log.Printf("error parsing config: %s", err)
|
||||
die()
|
||||
return
|
||||
}
|
||||
|
||||
if len(args) != 1 || args[0] == "" {
|
||||
err = fmt.Errorf("must specify client_identity as positional argument")
|
||||
die()
|
||||
log.Print("must specify client_identity as positional argument")
|
||||
return
|
||||
}
|
||||
|
||||
identity := args[0]
|
||||
unixaddr := path.Join(conf.Global.Serve.Stdinserver.SockDir, identity)
|
||||
|
||||
unixaddr, err := stdinserverListenerSocket(conf.Global.Serve.Stdinserver.SockDir, identity)
|
||||
if err != nil {
|
||||
log.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
|
||||
|
||||
ctx := netssh.ContextWithLog(context.TODO(), log)
|
||||
|
||||
err = netssh.Proxy(ctx, unixaddr)
|
||||
if err == nil {
|
||||
log.Print("proxying finished successfully, exiting with status 0")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
log.Printf("opening control connection to zrepld via %s", unixaddr)
|
||||
conn, err := net.DialUnix("unix", nil, unixaddr)
|
||||
if err != nil {
|
||||
log.Printf("error connecting to zrepld: %s", err)
|
||||
die()
|
||||
}
|
||||
|
||||
log.Printf("sending stdin and stdout fds to zrepld")
|
||||
err = fd.Put(conn, os.Stdin, os.Stdout)
|
||||
if err != nil {
|
||||
log.Printf("error: %s", err)
|
||||
die()
|
||||
}
|
||||
|
||||
log.Printf("waiting for zrepld to close control connection")
|
||||
for {
|
||||
|
||||
var buf [64]byte
|
||||
n, err := conn.Read(buf[:])
|
||||
if err == nil && n != 0 {
|
||||
log.Printf("protocol error: read expected to timeout or EOF returned bytes")
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
log.Printf("zrepld closed control connection, terminating")
|
||||
break
|
||||
}
|
||||
|
||||
neterr, ok := err.(net.Error)
|
||||
if !ok {
|
||||
log.Printf("received unexpected error type: %T %s", err, err)
|
||||
die()
|
||||
}
|
||||
if !neterr.Timeout() {
|
||||
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
|
||||
die()
|
||||
}
|
||||
// Read timed out, as expected
|
||||
}
|
||||
|
||||
return
|
||||
log.Printf("error proxying: %s", err)
|
||||
|
||||
}
|
||||
|
@ -1,123 +0,0 @@
|
||||
package sshbytestream
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/zrepl/zrepl/util"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Stderr []byte
|
||||
WaitErr error
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
return fmt.Sprintf("ssh command failed with error: %v. stderr:\n%s\n", e.WaitErr, e.Stderr)
|
||||
}
|
||||
|
||||
type SSHTransport struct {
|
||||
Host string
|
||||
User string
|
||||
Port uint16
|
||||
IdentityFile string
|
||||
SSHCommand string
|
||||
Options []string
|
||||
}
|
||||
|
||||
var SSHCommand string = "ssh"
|
||||
var SSHServerAliveInterval uint = 60
|
||||
|
||||
func Incoming() (wc io.ReadWriteCloser, err error) {
|
||||
// derivce ReadWriteCloser from stdin & stdout
|
||||
return IncomingReadWriteCloser{}, nil
|
||||
}
|
||||
|
||||
type IncomingReadWriteCloser struct{}
|
||||
|
||||
func (f IncomingReadWriteCloser) Read(p []byte) (n int, err error) {
|
||||
return os.Stdin.Read(p)
|
||||
}
|
||||
|
||||
func (f IncomingReadWriteCloser) Write(p []byte) (n int, err error) {
|
||||
return os.Stdout.Write(p)
|
||||
}
|
||||
|
||||
func (f IncomingReadWriteCloser) Close() (err error) {
|
||||
if err = os.Stdin.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
if err = os.Stdout.Close(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type OutgoingSSHByteStream struct {
|
||||
c *util.IOCommand
|
||||
}
|
||||
|
||||
func Outgoing(remote SSHTransport) (s OutgoingSSHByteStream, err error) {
|
||||
|
||||
sshArgs := make([]string, 0, 2*len(remote.Options)+4)
|
||||
sshArgs = append(sshArgs,
|
||||
"-p", fmt.Sprintf("%d", remote.Port),
|
||||
"-q",
|
||||
"-i", remote.IdentityFile,
|
||||
"-o", "BatchMode=yes",
|
||||
"-o", fmt.Sprintf("ServerAliveInterval=%d", SSHServerAliveInterval),
|
||||
)
|
||||
for _, option := range remote.Options {
|
||||
sshArgs = append(sshArgs, "-o", option)
|
||||
}
|
||||
sshArgs = append(sshArgs, fmt.Sprintf("%s@%s", remote.User, remote.Host))
|
||||
|
||||
var sshCommand = SSHCommand
|
||||
if len(remote.SSHCommand) > 0 {
|
||||
sshCommand = remote.SSHCommand
|
||||
}
|
||||
|
||||
if s.c, err = util.NewIOCommand(sshCommand, sshArgs, util.IOCommandStderrBufSize); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Clear environment of cmd, ssh shall not rely on SSH_AUTH_SOCK, etc.
|
||||
s.c.Cmd.Env = []string{}
|
||||
|
||||
err = s.c.Start()
|
||||
return
|
||||
}
|
||||
|
||||
func (s OutgoingSSHByteStream) Read(p []byte) (n int, err error) {
|
||||
return s.c.Read(p)
|
||||
}
|
||||
|
||||
func (s OutgoingSSHByteStream) Write(p []byte) (n int, err error) {
|
||||
return s.c.Write(p)
|
||||
}
|
||||
|
||||
func (s OutgoingSSHByteStream) Close() (err error) {
|
||||
err = s.c.Close()
|
||||
if err == nil || s.c.ExitResult == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// SSH catches SIGTERM and has different exit codes on different platforms
|
||||
ws := s.c.ExitResult.WaitStatus
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if ws.ExitStatus() == 128+int(syscall.SIGTERM) { // OpenSSH_7.5p1, OpenSSL 1.1.0f 25 May 2017 Arch Linux
|
||||
err = nil
|
||||
}
|
||||
case "freebsd": // OpenSSH_7.2p2, OpenSSL 1.0.2k-freebsd 26 Jan 2017
|
||||
if ws.ExitStatus() == 255 {
|
||||
err = nil
|
||||
}
|
||||
default: // TODO
|
||||
}
|
||||
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue
Block a user