From 308e5e35fb05ef4652f92009e200f8bac7b55eb3 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 4 Sep 2018 16:41:54 -0700 Subject: [PATCH] Multi-client servers + bring back stdinserver support --- client/stdinserver.go | 41 ++++++++++ config/config.go | 6 +- config/samples/sink.yml | 4 +- config/samples/source_ssh.yml | 4 +- daemon/job/sink.go | 47 +++++++----- daemon/logging/build_logging.go | 2 + daemon/serve/serve.go | 63 +++++++++++++++- daemon/serve/serve_stdinserver.go | 119 +++++++++++++++++++++++++----- daemon/serve/serve_tcp.go | 78 +++++++++++++++++++- daemon/serve/serve_tls.go | 28 +++++-- main.go | 13 ++++ tlsconf/tlsconf.go | 22 ++---- 12 files changed, 356 insertions(+), 71 deletions(-) create mode 100644 client/stdinserver.go diff --git a/client/stdinserver.go b/client/stdinserver.go new file mode 100644 index 0000000..9d47a25 --- /dev/null +++ b/client/stdinserver.go @@ -0,0 +1,41 @@ +package client + +import ( + "os" + + "context" + "github.com/problame/go-netssh" + "log" + "path" + "github.com/zrepl/zrepl/config" + "errors" +) + + +func RunStdinserver(config *config.Config, args []string) error { + + // NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong + defer os.Exit(1) + + log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime) + + if len(args) != 1 || args[0] == "" { + err := errors.New("must specify client_identity as positional argument") + return err + } + + identity := args[0] + unixaddr := path.Join(config.Global.Serve.StdinServer.SockDir, identity) + + log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr) + + ctx := netssh.ContextWithLog(context.TODO(), log) + + err := netssh.Proxy(ctx, unixaddr) + if err == nil { + log.Print("proxying finished successfully, exiting with status 0") + os.Exit(0) + } + log.Printf("error proxying: %s", err) + return nil +} diff --git a/config/config.go b/config/config.go index ab6f339..03fb662 100644 --- a/config/config.go +++ b/config/config.go @@ -165,7 +165,7 @@ type SSHStdinserverConnect struct { IdentityFile string `yaml:"identity_file"` TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused SSHCommand string `yaml:"ssh_command,optional"` //TODO unused - Options []string `yaml:"options"` + Options []string `yaml:"options,optional"` DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` } @@ -190,13 +190,13 @@ type TLSServe struct { Ca string `yaml:"ca"` Cert string `yaml:"cert"` Key string `yaml:"key"` - ClientCN string `yaml:"client_cn"` + ClientCNs []string `yaml:"client_cns"` HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"` } type StdinserverServer struct { ServeCommon `yaml:",inline"` - ClientIdentity string `yaml:"client_identity"` + ClientIdentities []string `yaml:"client_identities"` } type PruningEnum struct { diff --git a/config/samples/sink.yml b/config/samples/sink.yml index 7fbf716..b927041 100644 --- a/config/samples/sink.yml +++ b/config/samples/sink.yml @@ -8,7 +8,9 @@ jobs: ca: "ca.pem" cert: "cert.pem" key: "key.pem" - client_cn: "laptop1" + client_cns: + - "laptop1" + - "homeserver" global: logging: - type: "tcp" diff --git a/config/samples/source_ssh.yml b/config/samples/source_ssh.yml index c707f83..b1c034d 100644 --- a/config/samples/source_ssh.yml +++ b/config/samples/source_ssh.yml @@ -3,7 +3,9 @@ jobs: type: source serve: type: stdinserver - client_identity: "client1" + client_identities: + - "client1" + - "client2" filesystems: { "<": true, "secret": false diff --git a/daemon/job/sink.go b/daemon/job/sink.go index 7622167..e69a7fb 100644 --- a/daemon/job/sink.go +++ b/daemon/job/sink.go @@ -9,31 +9,27 @@ import ( "github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/daemon/serve" "github.com/zrepl/zrepl/endpoint" - "net" + "path" ) type Sink struct { name string l serve.ListenerFactory rpcConf *streamrpc.ConnConfig - fsmap endpoint.FSMap - fsmapInv endpoint.FSFilter + rootDataset string } func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) { - // FIXME multi client support - s = &Sink{name: in.Name} if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil { return nil, errors.Wrap(err, "cannot build server") } - fsmap := filters.NewDatasetMapFilter(1, false) // FIXME multi-client support - if err := fsmap.Add("<", in.RootDataset); err != nil { - return nil, errors.Wrap(err, "unexpected error: cannot build filesystem mapping") + if in.RootDataset == "" { + return nil, errors.Wrap(err, "must specify root dataset") } - s.fsmap = fsmap + s.rootDataset = in.RootDataset return s, nil } @@ -55,6 +51,7 @@ func (j *Sink) Run(ctx context.Context) { log.WithError(err).Error("cannot listen") return } + defer l.Close() log.WithField("addr", l.Addr()).Debug("accepting connections") @@ -64,10 +61,10 @@ outer: for { select { - case res := <-accept(l): + case res := <-accept(ctx, l): if res.err != nil { - log.WithError(err).Info("accept error") - break outer + log.WithError(res.err).Info("accept error") + continue } connId++ connLog := log. @@ -82,14 +79,28 @@ outer: } -func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) { +func (j *Sink) handleConnection(ctx context.Context, conn serve.AuthenticatedConn) { + defer conn.Close() + log := GetLogger(ctx) - log.WithField("addr", conn.RemoteAddr()).Info("handling connection") + log. + WithField("addr", conn.RemoteAddr()). + WithField("client_identity", conn.ClientIdentity()). + Info("handling connection") defer log.Info("finished handling connection") + clientRoot := path.Join(j.rootDataset, conn.ClientIdentity()) + log.WithField("client_root", clientRoot).Debug("client root") + fsmap := filters.NewDatasetMapFilter(1, false) + if err := fsmap.Add("<", clientRoot); err != nil { + log.WithError(err). + WithField("client_identity", conn.ClientIdentity()). + Error("cannot build client filesystem map (client identity must be a valid ZFS FS name") + } + ctx = logging.WithSubsystemLoggers(ctx, log) - local, err := endpoint.NewReceiver(j.fsmap, filters.NewAnyFSVFilter()) + local, err := endpoint.NewReceiver(fsmap, filters.NewAnyFSVFilter()) if err != nil { log.WithError(err).Error("unexpected error: cannot convert mapping to filter") return @@ -102,14 +113,14 @@ func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) { } type acceptResult struct { - conn net.Conn + conn serve.AuthenticatedConn err error } -func accept(listener net.Listener) <-chan acceptResult { +func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult { c := make(chan acceptResult, 1) go func() { - conn, err := listener.Accept() + conn, err := listener.Accept(ctx) c <- acceptResult{conn, err} }() return c diff --git a/daemon/logging/build_logging.go b/daemon/logging/build_logging.go index 56099bb..754fa2f 100644 --- a/daemon/logging/build_logging.go +++ b/daemon/logging/build_logging.go @@ -15,6 +15,7 @@ import ( "github.com/zrepl/zrepl/tlsconf" "os" "github.com/zrepl/zrepl/daemon/snapper" + "github.com/zrepl/zrepl/daemon/serve" ) func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) { @@ -71,6 +72,7 @@ func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Contex ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint")) ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning")) ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot")) + ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve")) return ctx } diff --git a/daemon/serve/serve.go b/daemon/serve/serve.go index fa7bb4a..8000a94 100644 --- a/daemon/serve/serve.go +++ b/daemon/serve/serve.go @@ -6,10 +6,69 @@ import ( "net" "github.com/zrepl/zrepl/daemon/streamrpcconfig" "github.com/problame/go-streamrpc" + "context" + "github.com/zrepl/zrepl/logger" + "github.com/zrepl/zrepl/zfs" ) +type contextKey int + +const contextKeyLog contextKey = 0 + +type Logger = logger.Logger + +func WithLogger(ctx context.Context, log Logger) context.Context { + return context.WithValue(ctx, contextKeyLog, log) +} + +func getLogger(ctx context.Context) Logger { + if log, ok := ctx.Value(contextKeyLog).(Logger); ok { + return log + } + return logger.NewNullLogger() +} + +type AuthenticatedConn interface { + net.Conn + // ClientIdentity must be a string that satisfies ValidateClientIdentity + ClientIdentity() string +} + +// A client identity must be a single component in a ZFS filesystem path +func ValidateClientIdentity(in string) (err error) { + path, err := zfs.NewDatasetPath(in) + if err != nil { + return err + } + if path.Length() != 1 { + return errors.New("client identity must be a single path comonent (not empty, no '/')") + } + return nil +} + +type authConn struct { + net.Conn + clientIdentity string +} + +var _ AuthenticatedConn = authConn{} + +func (c authConn) ClientIdentity() string { + if err := ValidateClientIdentity(c.clientIdentity); err != nil { + panic(err) + } + return c.clientIdentity +} + +// like net.Listener, but with an AuthenticatedConn instead of net.Conn +type AuthenticatedListener interface { + Addr() (net.Addr) + Accept(ctx context.Context) (AuthenticatedConn, error) + Close() error +} + type ListenerFactory interface { - Listen() (net.Listener, error) + Listen() (AuthenticatedListener, error) } func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) { @@ -25,7 +84,7 @@ func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf lf, lfError = TLSListenerFactoryFromConfig(g, v) conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) case *config.StdinserverServer: - lf, lfError = StdinserverListenerFactoryFromConfig(g, v) + lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v) conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) default: return nil, nil, errors.Errorf("internal error: unknown serve type %T", v) diff --git a/daemon/serve/serve_stdinserver.go b/daemon/serve/serve_stdinserver.go index f6403d3..baa8c88 100644 --- a/daemon/serve/serve_stdinserver.go +++ b/daemon/serve/serve_stdinserver.go @@ -8,54 +8,133 @@ import ( "net" "path" "time" + "context" + "github.com/pkg/errors" + "sync/atomic" + "fmt" + "os" ) type StdinserverListenerFactory struct { - ClientIdentity string - sockpath string + ClientIdentities []string + Sockdir string } -func StdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *StdinserverListenerFactory, err error) { +func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) { - f = &StdinserverListenerFactory{ - ClientIdentity: in.ClientIdentity, + for _, ci := range in.ClientIdentities { + if err := ValidateClientIdentity(ci); err != nil { + return nil, errors.Wrapf(err, "invalid client identity %q", ci) + } } - f.sockpath = path.Join(g.Serve.StdinServer.SockDir, f.ClientIdentity) + f = &multiStdinserverListenerFactory{ + ClientIdentities: in.ClientIdentities, + Sockdir: g.Serve.StdinServer.SockDir, + } return } -func (f *StdinserverListenerFactory) Listen() (net.Listener, error) { +type multiStdinserverListenerFactory struct { + ClientIdentities []string + Sockdir string +} - if err := nethelpers.PreparePrivateSockpath(f.sockpath); err != nil { - return nil, err +func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) { + return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities) +} + +type multiStdinserverAcceptRes struct { + conn AuthenticatedConn + err error +} + +type MultiStdinserverListener struct { + listeners []*stdinserverListener + accepts chan multiStdinserverAcceptRes + closed int32 +} + +// client identities must be validated +func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) { + listeners := make([]*stdinserverListener, 0, len(cis)) + var err error + for _, ci := range cis { + sockpath := path.Join(sockdir, ci) + l := &stdinserverListener{clientIdentity: ci} + if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil { + break + } + if l.l, err = netssh.Listen(sockpath); err != nil { + break + } + listeners = append(listeners, l) } - - l, err := netssh.Listen(f.sockpath) if err != nil { + for _, l := range listeners { + l.Close() // FIXME error reporting? + } return nil, err } - return StdinserverListener{l}, nil + return &MultiStdinserverListener{listeners: listeners}, nil } -type StdinserverListener struct { - l *netssh.Listener +func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){ + + if m.accepts == nil { + m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners)) + for i := range m.listeners { + go func(i int) { + for atomic.LoadInt32(&m.closed) == 0 { + fmt.Fprintf(os.Stderr, "accepting\n") + conn, err := m.listeners[i].Accept(context.TODO()) + fmt.Fprintf(os.Stderr, "incoming\n") + m.accepts <- multiStdinserverAcceptRes{conn, err} + } + }(i) + } + } + + res := <- m.accepts + return res.conn, res.err + } -func (l StdinserverListener) Addr() net.Addr { +func (m *MultiStdinserverListener) Addr() (net.Addr) { return netsshAddr{} } -func (l StdinserverListener) Accept() (net.Conn, error) { +func (m *MultiStdinserverListener) Close() error { + atomic.StoreInt32(&m.closed, 1) + var oneErr error + for _, l := range m.listeners { + if err := l.Close(); err != nil && oneErr == nil { + oneErr = err + } + } + return oneErr +} + +// a single stdinserverListener (part of multiStinserverListener) +type stdinserverListener struct { + l *netssh.Listener + clientIdentity string +} + +func (l stdinserverListener) Addr() net.Addr { + return netsshAddr{} +} + +func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) { c, err := l.l.Accept() if err != nil { return nil, err } - return netsshConnToNetConnAdatper{c}, nil + return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil } -func (l StdinserverListener) Close() (err error) { +func (l stdinserverListener) Close() (err error) { return l.l.Close() } @@ -66,12 +145,16 @@ func (netsshAddr) String() string { return "???" } type netsshConnToNetConnAdatper struct { io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn + clientIdentity string } +func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity } + func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} } func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} } +// FIXME log warning once! func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil } func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil } diff --git a/daemon/serve/serve_tcp.go b/daemon/serve/serve_tcp.go index 21cab59..957d3b9 100644 --- a/daemon/serve/serve_tcp.go +++ b/daemon/serve/serve_tcp.go @@ -3,19 +3,89 @@ package serve import ( "github.com/zrepl/zrepl/config" "net" + "github.com/pkg/errors" + "context" ) type TCPListenerFactory struct { - Address string + address *net.TCPAddr + clientMap *ipMap +} + +type ipMapEntry struct { + ip net.IP + ident string +} + +type ipMap struct { + entries []ipMapEntry +} + +func ipMapFromConfig(clients map[string]string) (*ipMap, error) { + entries := make([]ipMapEntry, 0, len(clients)) + for clientIPString, clientIdent := range clients { + clientIP := net.ParseIP(clientIPString) + if clientIP == nil { + return nil, errors.Errorf("cannot parse client IP %q", clientIPString) + } + if err := ValidateClientIdentity(clientIdent); err != nil { + return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString) + } + entries = append(entries, ipMapEntry{clientIP, clientIdent}) + } + return &ipMap{entries: entries}, nil +} + +func (m *ipMap) Get(ip net.IP) (string, error) { + for _, e := range m.entries { + if e.ip.Equal(ip) { + return e.ident, nil + } + } + return "", errors.Errorf("no identity mapping for client IP %s", ip) } func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) { + addr, err := net.ResolveTCPAddr("tcp", in.Listen) + if err != nil { + return nil, errors.Wrap(err, "cannot parse listen address") + } + clientMap, err := ipMapFromConfig(in.Clients) + if err != nil { + return nil, errors.Wrap(err, "cannot parse client IP map") + } lf := &TCPListenerFactory{ - Address: in.Listen, + address: addr, + clientMap: clientMap, } return lf, nil } -func (f *TCPListenerFactory) Listen() (net.Listener, error) { - return net.Listen("tcp", f.Address) +func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) { + l, err := net.ListenTCP("tcp", f.address) + if err != nil { + return nil, err + } + return &TCPAuthListener{l, f.clientMap}, nil } + +type TCPAuthListener struct { + *net.TCPListener + clientMap *ipMap +} + +func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { + nc, err := f.TCPListener.Accept() + if err != nil { + return nil, err + } + clientIP := nc.RemoteAddr().(*net.TCPAddr).IP + clientIdent, err := f.clientMap.Get(clientIP) + if err != nil { + getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map") + nc.Close() + return nil, err + } + return authConn{nc, clientIdent}, nil +} + diff --git a/daemon/serve/serve_tls.go b/daemon/serve/serve_tls.go index 0b80345..8f5e527 100644 --- a/daemon/serve/serve_tls.go +++ b/daemon/serve/serve_tls.go @@ -8,13 +8,13 @@ import ( "github.com/zrepl/zrepl/tlsconf" "net" "time" + "context" ) type TLSListenerFactory struct { address string clientCA *x509.CertPool serverCert tls.Certificate - clientCommonName string handshakeTimeout time.Duration } @@ -23,12 +23,10 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL address: in.Listen, } - if in.Ca == "" || in.Cert == "" || in.Key == "" || in.ClientCN == "" { - return nil, errors.New("fields 'ca', 'cert', 'key' and 'client_cn' must be specified") + if in.Ca == "" || in.Cert == "" || in.Key == "" { + return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified") } - lf.clientCommonName = in.ClientCN - lf.clientCA, err = tlsconf.ParseCAFile(in.Ca) if err != nil { return nil, errors.Wrap(err, "cannot parse ca file") @@ -42,11 +40,25 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL return lf, nil } -func (f *TLSListenerFactory) Listen() (net.Listener, error) { +func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) { l, err := net.Listen("tcp", f.address) if err != nil { return nil, err } - tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.clientCommonName, f.handshakeTimeout) - return tl, nil + tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout) + return tlsAuthListener{tl}, nil } + +type tlsAuthListener struct { + *tlsconf.ClientAuthListener +} + +func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { + c, cn, err := l.ClientAuthListener.Accept() + if err != nil { + return nil, err + } + return authConn{c, cn}, nil +} + + diff --git a/main.go b/main.go index 40f1f33..c520ce3 100644 --- a/main.go +++ b/main.go @@ -57,6 +57,18 @@ var statusCmd = &cobra.Command{ }, } +var stdinserverCmd = &cobra.Command{ + Use: "stdinserver CLIENT_IDENTITY", + Short: "start in stdinserver mode (from authorized_keys file)", + RunE: func(cmd *cobra.Command, args []string) error { + conf, err := config.ParseConfig(rootArgs.configFile) + if err != nil { + return err + } + return client.RunStdinserver(conf, args) + }, +} + var rootArgs struct { configFile string } @@ -67,6 +79,7 @@ func init() { rootCmd.AddCommand(daemonCmd) rootCmd.AddCommand(wakeupCmd) rootCmd.AddCommand(statusCmd) + rootCmd.AddCommand(stdinserverCmd) } func main() { diff --git a/tlsconf/tlsconf.go b/tlsconf/tlsconf.go index cf47968..48fc382 100644 --- a/tlsconf/tlsconf.go +++ b/tlsconf/tlsconf.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "crypto/x509" "errors" - "fmt" "io/ioutil" "net" "time" @@ -24,13 +23,12 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) { type ClientAuthListener struct { l net.Listener - clientCommonName string handshakeTimeout time.Duration } func NewClientAuthListener( l net.Listener, ca *x509.CertPool, serverCert tls.Certificate, - clientCommonName string, handshakeTimeout time.Duration) *ClientAuthListener { + handshakeTimeout time.Duration) *ClientAuthListener { if ca == nil { panic(ca) @@ -38,9 +36,6 @@ func NewClientAuthListener( if serverCert.Certificate == nil || serverCert.PrivateKey == nil { panic(serverCert) } - if clientCommonName == "" { - panic(clientCommonName) - } tlsConf := tls.Config{ Certificates: []tls.Certificate{serverCert}, @@ -51,19 +46,18 @@ func NewClientAuthListener( l = tls.NewListener(l, &tlsConf) return &ClientAuthListener{ l, - clientCommonName, handshakeTimeout, } } -func (l *ClientAuthListener) Accept() (c net.Conn, err error) { +func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) { c, err = l.l.Accept() if err != nil { - return nil, err + return nil, "", err } tlsConn, ok := c.(*tls.Conn) if !ok { - return c, err + return c, "", err } var ( @@ -83,14 +77,10 @@ func (l *ClientAuthListener) Accept() (c net.Conn, err error) { goto CloseAndErr } cn = peerCerts[0].Subject.CommonName - if cn != l.clientCommonName { - err = fmt.Errorf("client cert common name does not match client_identity: %q != %q", cn, l.clientCommonName) - goto CloseAndErr - } - return c, nil + return c, cn, nil CloseAndErr: c.Close() - return nil, err + return nil, "", err } func (l *ClientAuthListener) Addr() net.Addr {