Multi-client servers + bring back stdinserver support

This commit is contained in:
Christian Schwarz 2018-09-04 16:41:54 -07:00
parent e161347e47
commit 308e5e35fb
12 changed files with 356 additions and 71 deletions

41
client/stdinserver.go Normal file
View File

@ -0,0 +1,41 @@
package client
import (
"os"
"context"
"github.com/problame/go-netssh"
"log"
"path"
"github.com/zrepl/zrepl/config"
"errors"
)
func RunStdinserver(config *config.Config, args []string) error {
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
defer os.Exit(1)
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
if len(args) != 1 || args[0] == "" {
err := errors.New("must specify client_identity as positional argument")
return err
}
identity := args[0]
unixaddr := path.Join(config.Global.Serve.StdinServer.SockDir, identity)
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
ctx := netssh.ContextWithLog(context.TODO(), log)
err := netssh.Proxy(ctx, unixaddr)
if err == nil {
log.Print("proxying finished successfully, exiting with status 0")
os.Exit(0)
}
log.Printf("error proxying: %s", err)
return nil
}

View File

@ -165,7 +165,7 @@ type SSHStdinserverConnect struct {
IdentityFile string `yaml:"identity_file"`
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
Options []string `yaml:"options"`
Options []string `yaml:"options,optional"`
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
}
@ -190,13 +190,13 @@ type TLSServe struct {
Ca string `yaml:"ca"`
Cert string `yaml:"cert"`
Key string `yaml:"key"`
ClientCN string `yaml:"client_cn"`
ClientCNs []string `yaml:"client_cns"`
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"`
}
type StdinserverServer struct {
ServeCommon `yaml:",inline"`
ClientIdentity string `yaml:"client_identity"`
ClientIdentities []string `yaml:"client_identities"`
}
type PruningEnum struct {

View File

@ -8,7 +8,9 @@ jobs:
ca: "ca.pem"
cert: "cert.pem"
key: "key.pem"
client_cn: "laptop1"
client_cns:
- "laptop1"
- "homeserver"
global:
logging:
- type: "tcp"

View File

@ -3,7 +3,9 @@ jobs:
type: source
serve:
type: stdinserver
client_identity: "client1"
client_identities:
- "client1"
- "client2"
filesystems: {
"<": true,
"secret": false

View File

@ -9,31 +9,27 @@ import (
"github.com/zrepl/zrepl/daemon/logging"
"github.com/zrepl/zrepl/daemon/serve"
"github.com/zrepl/zrepl/endpoint"
"net"
"path"
)
type Sink struct {
name string
l serve.ListenerFactory
rpcConf *streamrpc.ConnConfig
fsmap endpoint.FSMap
fsmapInv endpoint.FSFilter
rootDataset string
}
func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) {
// FIXME multi client support
s = &Sink{name: in.Name}
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil {
return nil, errors.Wrap(err, "cannot build server")
}
fsmap := filters.NewDatasetMapFilter(1, false) // FIXME multi-client support
if err := fsmap.Add("<", in.RootDataset); err != nil {
return nil, errors.Wrap(err, "unexpected error: cannot build filesystem mapping")
if in.RootDataset == "" {
return nil, errors.Wrap(err, "must specify root dataset")
}
s.fsmap = fsmap
s.rootDataset = in.RootDataset
return s, nil
}
@ -55,6 +51,7 @@ func (j *Sink) Run(ctx context.Context) {
log.WithError(err).Error("cannot listen")
return
}
defer l.Close()
log.WithField("addr", l.Addr()).Debug("accepting connections")
@ -64,10 +61,10 @@ outer:
for {
select {
case res := <-accept(l):
case res := <-accept(ctx, l):
if res.err != nil {
log.WithError(err).Info("accept error")
break outer
log.WithError(res.err).Info("accept error")
continue
}
connId++
connLog := log.
@ -82,14 +79,28 @@ outer:
}
func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
func (j *Sink) handleConnection(ctx context.Context, conn serve.AuthenticatedConn) {
defer conn.Close()
log := GetLogger(ctx)
log.WithField("addr", conn.RemoteAddr()).Info("handling connection")
log.
WithField("addr", conn.RemoteAddr()).
WithField("client_identity", conn.ClientIdentity()).
Info("handling connection")
defer log.Info("finished handling connection")
clientRoot := path.Join(j.rootDataset, conn.ClientIdentity())
log.WithField("client_root", clientRoot).Debug("client root")
fsmap := filters.NewDatasetMapFilter(1, false)
if err := fsmap.Add("<", clientRoot); err != nil {
log.WithError(err).
WithField("client_identity", conn.ClientIdentity()).
Error("cannot build client filesystem map (client identity must be a valid ZFS FS name")
}
ctx = logging.WithSubsystemLoggers(ctx, log)
local, err := endpoint.NewReceiver(j.fsmap, filters.NewAnyFSVFilter())
local, err := endpoint.NewReceiver(fsmap, filters.NewAnyFSVFilter())
if err != nil {
log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
return
@ -102,14 +113,14 @@ func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
}
type acceptResult struct {
conn net.Conn
conn serve.AuthenticatedConn
err error
}
func accept(listener net.Listener) <-chan acceptResult {
func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult {
c := make(chan acceptResult, 1)
go func() {
conn, err := listener.Accept()
conn, err := listener.Accept(ctx)
c <- acceptResult{conn, err}
}()
return c

View File

@ -15,6 +15,7 @@ import (
"github.com/zrepl/zrepl/tlsconf"
"os"
"github.com/zrepl/zrepl/daemon/snapper"
"github.com/zrepl/zrepl/daemon/serve"
)
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
@ -71,6 +72,7 @@ func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Contex
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint"))
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning"))
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot"))
ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve"))
return ctx
}

View File

@ -6,10 +6,69 @@ import (
"net"
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
"github.com/problame/go-streamrpc"
"context"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/zfs"
)
type contextKey int
const contextKeyLog contextKey = 0
type Logger = logger.Logger
func WithLogger(ctx context.Context, log Logger) context.Context {
return context.WithValue(ctx, contextKeyLog, log)
}
func getLogger(ctx context.Context) Logger {
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
return log
}
return logger.NewNullLogger()
}
type AuthenticatedConn interface {
net.Conn
// ClientIdentity must be a string that satisfies ValidateClientIdentity
ClientIdentity() string
}
// A client identity must be a single component in a ZFS filesystem path
func ValidateClientIdentity(in string) (err error) {
path, err := zfs.NewDatasetPath(in)
if err != nil {
return err
}
if path.Length() != 1 {
return errors.New("client identity must be a single path comonent (not empty, no '/')")
}
return nil
}
type authConn struct {
net.Conn
clientIdentity string
}
var _ AuthenticatedConn = authConn{}
func (c authConn) ClientIdentity() string {
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
panic(err)
}
return c.clientIdentity
}
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
type AuthenticatedListener interface {
Addr() (net.Addr)
Accept(ctx context.Context) (AuthenticatedConn, error)
Close() error
}
type ListenerFactory interface {
Listen() (net.Listener, error)
Listen() (AuthenticatedListener, error)
}
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
@ -25,7 +84,7 @@ func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf
lf, lfError = TLSListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.StdinserverServer:
lf, lfError = StdinserverListenerFactoryFromConfig(g, v)
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
default:
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)

View File

@ -8,54 +8,133 @@ import (
"net"
"path"
"time"
"context"
"github.com/pkg/errors"
"sync/atomic"
"fmt"
"os"
)
type StdinserverListenerFactory struct {
ClientIdentity string
sockpath string
ClientIdentities []string
Sockdir string
}
func StdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *StdinserverListenerFactory, err error) {
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
f = &StdinserverListenerFactory{
ClientIdentity: in.ClientIdentity,
for _, ci := range in.ClientIdentities {
if err := ValidateClientIdentity(ci); err != nil {
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
}
}
f.sockpath = path.Join(g.Serve.StdinServer.SockDir, f.ClientIdentity)
f = &multiStdinserverListenerFactory{
ClientIdentities: in.ClientIdentities,
Sockdir: g.Serve.StdinServer.SockDir,
}
return
}
func (f *StdinserverListenerFactory) Listen() (net.Listener, error) {
type multiStdinserverListenerFactory struct {
ClientIdentities []string
Sockdir string
}
if err := nethelpers.PreparePrivateSockpath(f.sockpath); err != nil {
return nil, err
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
}
type multiStdinserverAcceptRes struct {
conn AuthenticatedConn
err error
}
type MultiStdinserverListener struct {
listeners []*stdinserverListener
accepts chan multiStdinserverAcceptRes
closed int32
}
// client identities must be validated
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
listeners := make([]*stdinserverListener, 0, len(cis))
var err error
for _, ci := range cis {
sockpath := path.Join(sockdir, ci)
l := &stdinserverListener{clientIdentity: ci}
if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
break
}
if l.l, err = netssh.Listen(sockpath); err != nil {
break
}
listeners = append(listeners, l)
}
l, err := netssh.Listen(f.sockpath)
if err != nil {
for _, l := range listeners {
l.Close() // FIXME error reporting?
}
return nil, err
}
return StdinserverListener{l}, nil
return &MultiStdinserverListener{listeners: listeners}, nil
}
type StdinserverListener struct {
l *netssh.Listener
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){
if m.accepts == nil {
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
for i := range m.listeners {
go func(i int) {
for atomic.LoadInt32(&m.closed) == 0 {
fmt.Fprintf(os.Stderr, "accepting\n")
conn, err := m.listeners[i].Accept(context.TODO())
fmt.Fprintf(os.Stderr, "incoming\n")
m.accepts <- multiStdinserverAcceptRes{conn, err}
}
}(i)
}
}
res := <- m.accepts
return res.conn, res.err
}
func (l StdinserverListener) Addr() net.Addr {
func (m *MultiStdinserverListener) Addr() (net.Addr) {
return netsshAddr{}
}
func (l StdinserverListener) Accept() (net.Conn, error) {
func (m *MultiStdinserverListener) Close() error {
atomic.StoreInt32(&m.closed, 1)
var oneErr error
for _, l := range m.listeners {
if err := l.Close(); err != nil && oneErr == nil {
oneErr = err
}
}
return oneErr
}
// a single stdinserverListener (part of multiStinserverListener)
type stdinserverListener struct {
l *netssh.Listener
clientIdentity string
}
func (l stdinserverListener) Addr() net.Addr {
return netsshAddr{}
}
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
c, err := l.l.Accept()
if err != nil {
return nil, err
}
return netsshConnToNetConnAdatper{c}, nil
return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil
}
func (l StdinserverListener) Close() (err error) {
func (l stdinserverListener) Close() (err error) {
return l.l.Close()
}
@ -66,12 +145,16 @@ func (netsshAddr) String() string { return "???" }
type netsshConnToNetConnAdatper struct {
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
clientIdentity string
}
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
// FIXME log warning once!
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }

View File

@ -3,19 +3,89 @@ package serve
import (
"github.com/zrepl/zrepl/config"
"net"
"github.com/pkg/errors"
"context"
)
type TCPListenerFactory struct {
Address string
address *net.TCPAddr
clientMap *ipMap
}
type ipMapEntry struct {
ip net.IP
ident string
}
type ipMap struct {
entries []ipMapEntry
}
func ipMapFromConfig(clients map[string]string) (*ipMap, error) {
entries := make([]ipMapEntry, 0, len(clients))
for clientIPString, clientIdent := range clients {
clientIP := net.ParseIP(clientIPString)
if clientIP == nil {
return nil, errors.Errorf("cannot parse client IP %q", clientIPString)
}
if err := ValidateClientIdentity(clientIdent); err != nil {
return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString)
}
entries = append(entries, ipMapEntry{clientIP, clientIdent})
}
return &ipMap{entries: entries}, nil
}
func (m *ipMap) Get(ip net.IP) (string, error) {
for _, e := range m.entries {
if e.ip.Equal(ip) {
return e.ident, nil
}
}
return "", errors.Errorf("no identity mapping for client IP %s", ip)
}
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) {
addr, err := net.ResolveTCPAddr("tcp", in.Listen)
if err != nil {
return nil, errors.Wrap(err, "cannot parse listen address")
}
clientMap, err := ipMapFromConfig(in.Clients)
if err != nil {
return nil, errors.Wrap(err, "cannot parse client IP map")
}
lf := &TCPListenerFactory{
Address: in.Listen,
address: addr,
clientMap: clientMap,
}
return lf, nil
}
func (f *TCPListenerFactory) Listen() (net.Listener, error) {
return net.Listen("tcp", f.Address)
func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) {
l, err := net.ListenTCP("tcp", f.address)
if err != nil {
return nil, err
}
return &TCPAuthListener{l, f.clientMap}, nil
}
type TCPAuthListener struct {
*net.TCPListener
clientMap *ipMap
}
func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
nc, err := f.TCPListener.Accept()
if err != nil {
return nil, err
}
clientIP := nc.RemoteAddr().(*net.TCPAddr).IP
clientIdent, err := f.clientMap.Get(clientIP)
if err != nil {
getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
nc.Close()
return nil, err
}
return authConn{nc, clientIdent}, nil
}

View File

@ -8,13 +8,13 @@ import (
"github.com/zrepl/zrepl/tlsconf"
"net"
"time"
"context"
)
type TLSListenerFactory struct {
address string
clientCA *x509.CertPool
serverCert tls.Certificate
clientCommonName string
handshakeTimeout time.Duration
}
@ -23,12 +23,10 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
address: in.Listen,
}
if in.Ca == "" || in.Cert == "" || in.Key == "" || in.ClientCN == "" {
return nil, errors.New("fields 'ca', 'cert', 'key' and 'client_cn' must be specified")
if in.Ca == "" || in.Cert == "" || in.Key == "" {
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
}
lf.clientCommonName = in.ClientCN
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
if err != nil {
return nil, errors.Wrap(err, "cannot parse ca file")
@ -42,11 +40,25 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
return lf, nil
}
func (f *TLSListenerFactory) Listen() (net.Listener, error) {
func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) {
l, err := net.Listen("tcp", f.address)
if err != nil {
return nil, err
}
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.clientCommonName, f.handshakeTimeout)
return tl, nil
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout)
return tlsAuthListener{tl}, nil
}
type tlsAuthListener struct {
*tlsconf.ClientAuthListener
}
func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
c, cn, err := l.ClientAuthListener.Accept()
if err != nil {
return nil, err
}
return authConn{c, cn}, nil
}

13
main.go
View File

@ -57,6 +57,18 @@ var statusCmd = &cobra.Command{
},
}
var stdinserverCmd = &cobra.Command{
Use: "stdinserver CLIENT_IDENTITY",
Short: "start in stdinserver mode (from authorized_keys file)",
RunE: func(cmd *cobra.Command, args []string) error {
conf, err := config.ParseConfig(rootArgs.configFile)
if err != nil {
return err
}
return client.RunStdinserver(conf, args)
},
}
var rootArgs struct {
configFile string
}
@ -67,6 +79,7 @@ func init() {
rootCmd.AddCommand(daemonCmd)
rootCmd.AddCommand(wakeupCmd)
rootCmd.AddCommand(statusCmd)
rootCmd.AddCommand(stdinserverCmd)
}
func main() {

View File

@ -4,7 +4,6 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"time"
@ -24,13 +23,12 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
type ClientAuthListener struct {
l net.Listener
clientCommonName string
handshakeTimeout time.Duration
}
func NewClientAuthListener(
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
clientCommonName string, handshakeTimeout time.Duration) *ClientAuthListener {
handshakeTimeout time.Duration) *ClientAuthListener {
if ca == nil {
panic(ca)
@ -38,9 +36,6 @@ func NewClientAuthListener(
if serverCert.Certificate == nil || serverCert.PrivateKey == nil {
panic(serverCert)
}
if clientCommonName == "" {
panic(clientCommonName)
}
tlsConf := tls.Config{
Certificates: []tls.Certificate{serverCert},
@ -51,19 +46,18 @@ func NewClientAuthListener(
l = tls.NewListener(l, &tlsConf)
return &ClientAuthListener{
l,
clientCommonName,
handshakeTimeout,
}
}
func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
c, err = l.l.Accept()
if err != nil {
return nil, err
return nil, "", err
}
tlsConn, ok := c.(*tls.Conn)
if !ok {
return c, err
return c, "", err
}
var (
@ -83,14 +77,10 @@ func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
goto CloseAndErr
}
cn = peerCerts[0].Subject.CommonName
if cn != l.clientCommonName {
err = fmt.Errorf("client cert common name does not match client_identity: %q != %q", cn, l.clientCommonName)
goto CloseAndErr
}
return c, nil
return c, cn, nil
CloseAndErr:
c.Close()
return nil, err
return nil, "", err
}
func (l *ClientAuthListener) Addr() net.Addr {