mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-22 00:13:52 +01:00
implement stdinserver command + corresponding server
How it works: `zrepl stdinserver CLIENT_IDENTITY` * connects to the socket in $global.serve.stdinserver.sockdir/CLIENT_IDENTITY * sends its stdin / stdout file descriptors to the `zrepl daemon` process (see cmsg(3)) * does nothing more This enables a setup where `zrepl daemon` is not directly exposed to the internet but instead all traffic is tunnelled through SSH. The server with the source job has an authorized_keys file entry for the public key used by the corresponding pull job command="/mnt/zrepl stdinserver CLIENT_IDENTITY" ssh-ed25519 AAAAC3NzaC1E... zrepl@pullingserver
This commit is contained in:
parent
f3689563b5
commit
ce25c01c7e
@ -24,11 +24,12 @@ type RPCConnecter interface {
|
||||
Connect() (rpc.RPCClient, error)
|
||||
}
|
||||
type AuthenticatedChannelListenerFactory interface {
|
||||
Listen() AuthenticatedChannelListener
|
||||
Listen() (AuthenticatedChannelListener, error)
|
||||
}
|
||||
|
||||
type AuthenticatedChannelListener interface {
|
||||
Accept() (ch io.ReadWriteCloser, err error)
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
type SSHStdinServerConnectDescr struct {
|
||||
@ -37,4 +38,3 @@ type SSHStdinServerConnectDescr struct {
|
||||
type PrunePolicy interface {
|
||||
Prune(fs zfs.DatasetPath, versions []zfs.FilesystemVersion) (keep, remote []zfs.FilesystemVersion, err error)
|
||||
}
|
||||
|
||||
|
@ -1,27 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import "github.com/pkg/errors"
|
||||
|
||||
type StdinserverListenerFactory struct {
|
||||
ClientIdentity string
|
||||
}
|
||||
|
||||
func (StdinserverListenerFactory) Listen() AuthenticatedChannelListener {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func parseStdinserverListenerFactory(i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
||||
|
||||
ci, ok := i["client_identity"]
|
||||
if !ok {
|
||||
err = errors.Errorf("must specify 'client_identity'")
|
||||
return
|
||||
}
|
||||
cs, ok := ci.(string)
|
||||
if !ok {
|
||||
err = errors.Errorf("must specify 'client_identity' as string, got %T", cs)
|
||||
return
|
||||
}
|
||||
f = &StdinserverListenerFactory{ClientIdentity: cs}
|
||||
return
|
||||
}
|
123
cmd/config_serve_stdinserver.go
Normal file
123
cmd/config_serve_stdinserver.go
Normal file
@ -0,0 +1,123 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/ftrvxmtrx/fd"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zrepl/zrepl/util"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type StdinserverListenerFactory struct {
|
||||
ClientIdentity string `mapstructure:"client_identity"`
|
||||
ConnLogReadFile string `mapstructure:"connlog_read_file"`
|
||||
ConnLogWriteFile string `mapstructure:"connlog_write_file"`
|
||||
}
|
||||
|
||||
func parseStdinserverListenerFactory(i map[string]interface{}) (f *StdinserverListenerFactory, err error) {
|
||||
|
||||
f = &StdinserverListenerFactory{}
|
||||
|
||||
if err = mapstructure.Decode(i, f); err != nil {
|
||||
return nil, errors.Wrap(err, "mapstructure error")
|
||||
}
|
||||
if !(len(f.ClientIdentity) > 0) {
|
||||
err = errors.Errorf("must specify 'client_identity'")
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func stdinserverListenerSockpath(clientIdentity string) (addr *net.UnixAddr, err error) {
|
||||
sockpath := path.Join(conf.Global.Serve.Stdinserver.SockDir, clientIdentity)
|
||||
addr, err = net.ResolveUnixAddr("unix", sockpath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot resolve unix address")
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func (f *StdinserverListenerFactory) Listen() (al AuthenticatedChannelListener, err error) {
|
||||
|
||||
unixaddr, err := stdinserverListenerSockpath(f.ClientIdentity)
|
||||
|
||||
sockdir := filepath.Dir(unixaddr.Name)
|
||||
sdstat, err := os.Stat(sockdir)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "cannot stat(2) sockdir '%s'", sockdir)
|
||||
}
|
||||
if !sdstat.IsDir() {
|
||||
return nil, errors.Errorf("sockdir is not a directory: %s", sockdir)
|
||||
}
|
||||
p := sdstat.Mode().Perm()
|
||||
if p&0007 != 0 {
|
||||
return nil, errors.Errorf("sockdir must not be world-accessible (permissions are %#o)", p)
|
||||
}
|
||||
|
||||
ul, err := net.ListenUnix("unix", unixaddr) // TODO
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "cannot listen on unix socket %s", unixaddr)
|
||||
}
|
||||
|
||||
l := &StdinserverListener{ul, f.ConnLogReadFile, f.ConnLogWriteFile}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type StdinserverListener struct {
|
||||
l *net.UnixListener
|
||||
ConnLogReadFile string `mapstructure:"connlog_read_file"`
|
||||
ConnLogWriteFile string `mapstructure:"connlog_write_file"`
|
||||
}
|
||||
|
||||
type fdRWC struct {
|
||||
stdin, stdout *os.File
|
||||
control *net.UnixConn
|
||||
}
|
||||
|
||||
func (f fdRWC) Read(p []byte) (n int, err error) {
|
||||
return f.stdin.Read(p)
|
||||
}
|
||||
|
||||
func (f fdRWC) Write(p []byte) (n int, err error) {
|
||||
return f.stdout.Write(p)
|
||||
}
|
||||
|
||||
func (f fdRWC) Close() (err error) {
|
||||
f.stdin.Close()
|
||||
f.stdout.Close()
|
||||
return f.control.Close()
|
||||
}
|
||||
|
||||
func (l *StdinserverListener) Accept() (ch io.ReadWriteCloser, err error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "error accepting on unix listener")
|
||||
return
|
||||
}
|
||||
|
||||
// Read the stdin and stdout of the stdinserver command
|
||||
files, err := fd.Get(c.(*net.UnixConn), 2, []string{"stdin", "stdout"})
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "error receiving fds from stdinserver command")
|
||||
c.Close()
|
||||
}
|
||||
|
||||
rwc := fdRWC{files[0], files[1], c.(*net.UnixConn)}
|
||||
|
||||
rwclog, err := util.NewReadWriteCloserLogger(rwc, l.ConnLogReadFile, l.ConnLogWriteFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return rwclog, nil
|
||||
|
||||
}
|
||||
|
||||
func (l *StdinserverListener) Close() (err error) {
|
||||
return l.l.Close() // removes socket file automatically
|
||||
}
|
87
cmd/stdinserver.go
Normal file
87
cmd/stdinserver.go
Normal file
@ -0,0 +1,87 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ftrvxmtrx/fd"
|
||||
"github.com/spf13/cobra"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
var StdinserverCmd = &cobra.Command{
|
||||
Use: "stdinserver CLIENT_IDENTITY",
|
||||
Short: "start in stdinserver mode (from authorized_keys file)",
|
||||
Run: cmdStdinServer,
|
||||
}
|
||||
|
||||
func init() {
|
||||
RootCmd.AddCommand(StdinserverCmd)
|
||||
}
|
||||
|
||||
func cmdStdinServer(cmd *cobra.Command, args []string) {
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.Printf("stdinserver exiting with error: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(args) != 1 || args[0] == "" {
|
||||
err = fmt.Errorf("must specify client_identity as positional argument")
|
||||
return
|
||||
}
|
||||
identity := args[0]
|
||||
|
||||
unixaddr, err := stdinserverListenerSockpath(identity)
|
||||
if err != nil {
|
||||
log.Printf("%s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Printf("opening control connection to zrepld via %s", unixaddr)
|
||||
conn, err := net.DialUnix("unix", nil, unixaddr)
|
||||
if err != nil {
|
||||
log.Printf("error connecting to zrepld: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Printf("sending stdin and stdout fds to zrepld")
|
||||
err = fd.Put(conn, os.Stdin, os.Stdout)
|
||||
if err != nil {
|
||||
log.Printf("error: %s", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Printf("waiting for zrepld to close control connection")
|
||||
for {
|
||||
|
||||
var buf [64]byte
|
||||
n, err := conn.Read(buf[:])
|
||||
if err == nil && n != 0 {
|
||||
log.Printf("protocol error: read expected to timeout or EOF returned bytes")
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
log.Printf("zrepld closed control connection, terminating")
|
||||
break
|
||||
}
|
||||
|
||||
neterr, ok := err.(net.Error)
|
||||
if !ok {
|
||||
log.Printf("received unexpected error type: %T %s", err, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if !neterr.Timeout() {
|
||||
log.Printf("receivd unexpected net.Error (not a timeout): %s", neterr)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Read timed out, as expected
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user