mirror of
https://github.com/zrepl/zrepl.git
synced 2025-06-19 00:07:10 +02:00
Multi-client servers + bring back stdinserver support
This commit is contained in:
parent
e161347e47
commit
308e5e35fb
41
client/stdinserver.go
Normal file
41
client/stdinserver.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"context"
|
||||||
|
"github.com/problame/go-netssh"
|
||||||
|
"log"
|
||||||
|
"path"
|
||||||
|
"github.com/zrepl/zrepl/config"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
func RunStdinserver(config *config.Config, args []string) error {
|
||||||
|
|
||||||
|
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
|
||||||
|
defer os.Exit(1)
|
||||||
|
|
||||||
|
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||||
|
|
||||||
|
if len(args) != 1 || args[0] == "" {
|
||||||
|
err := errors.New("must specify client_identity as positional argument")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
identity := args[0]
|
||||||
|
unixaddr := path.Join(config.Global.Serve.StdinServer.SockDir, identity)
|
||||||
|
|
||||||
|
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
|
||||||
|
|
||||||
|
ctx := netssh.ContextWithLog(context.TODO(), log)
|
||||||
|
|
||||||
|
err := netssh.Proxy(ctx, unixaddr)
|
||||||
|
if err == nil {
|
||||||
|
log.Print("proxying finished successfully, exiting with status 0")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
log.Printf("error proxying: %s", err)
|
||||||
|
return nil
|
||||||
|
}
|
@ -165,7 +165,7 @@ type SSHStdinserverConnect struct {
|
|||||||
IdentityFile string `yaml:"identity_file"`
|
IdentityFile string `yaml:"identity_file"`
|
||||||
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
|
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
|
||||||
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
|
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
|
||||||
Options []string `yaml:"options"`
|
Options []string `yaml:"options,optional"`
|
||||||
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
|
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,13 +190,13 @@ type TLSServe struct {
|
|||||||
Ca string `yaml:"ca"`
|
Ca string `yaml:"ca"`
|
||||||
Cert string `yaml:"cert"`
|
Cert string `yaml:"cert"`
|
||||||
Key string `yaml:"key"`
|
Key string `yaml:"key"`
|
||||||
ClientCN string `yaml:"client_cn"`
|
ClientCNs []string `yaml:"client_cns"`
|
||||||
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"`
|
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StdinserverServer struct {
|
type StdinserverServer struct {
|
||||||
ServeCommon `yaml:",inline"`
|
ServeCommon `yaml:",inline"`
|
||||||
ClientIdentity string `yaml:"client_identity"`
|
ClientIdentities []string `yaml:"client_identities"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type PruningEnum struct {
|
type PruningEnum struct {
|
||||||
|
@ -8,7 +8,9 @@ jobs:
|
|||||||
ca: "ca.pem"
|
ca: "ca.pem"
|
||||||
cert: "cert.pem"
|
cert: "cert.pem"
|
||||||
key: "key.pem"
|
key: "key.pem"
|
||||||
client_cn: "laptop1"
|
client_cns:
|
||||||
|
- "laptop1"
|
||||||
|
- "homeserver"
|
||||||
global:
|
global:
|
||||||
logging:
|
logging:
|
||||||
- type: "tcp"
|
- type: "tcp"
|
||||||
|
@ -3,7 +3,9 @@ jobs:
|
|||||||
type: source
|
type: source
|
||||||
serve:
|
serve:
|
||||||
type: stdinserver
|
type: stdinserver
|
||||||
client_identity: "client1"
|
client_identities:
|
||||||
|
- "client1"
|
||||||
|
- "client2"
|
||||||
filesystems: {
|
filesystems: {
|
||||||
"<": true,
|
"<": true,
|
||||||
"secret": false
|
"secret": false
|
||||||
|
@ -9,31 +9,27 @@ import (
|
|||||||
"github.com/zrepl/zrepl/daemon/logging"
|
"github.com/zrepl/zrepl/daemon/logging"
|
||||||
"github.com/zrepl/zrepl/daemon/serve"
|
"github.com/zrepl/zrepl/daemon/serve"
|
||||||
"github.com/zrepl/zrepl/endpoint"
|
"github.com/zrepl/zrepl/endpoint"
|
||||||
"net"
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Sink struct {
|
type Sink struct {
|
||||||
name string
|
name string
|
||||||
l serve.ListenerFactory
|
l serve.ListenerFactory
|
||||||
rpcConf *streamrpc.ConnConfig
|
rpcConf *streamrpc.ConnConfig
|
||||||
fsmap endpoint.FSMap
|
rootDataset string
|
||||||
fsmapInv endpoint.FSFilter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) {
|
func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) {
|
||||||
|
|
||||||
// FIXME multi client support
|
|
||||||
|
|
||||||
s = &Sink{name: in.Name}
|
s = &Sink{name: in.Name}
|
||||||
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil {
|
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot build server")
|
return nil, errors.Wrap(err, "cannot build server")
|
||||||
}
|
}
|
||||||
|
|
||||||
fsmap := filters.NewDatasetMapFilter(1, false) // FIXME multi-client support
|
if in.RootDataset == "" {
|
||||||
if err := fsmap.Add("<", in.RootDataset); err != nil {
|
return nil, errors.Wrap(err, "must specify root dataset")
|
||||||
return nil, errors.Wrap(err, "unexpected error: cannot build filesystem mapping")
|
|
||||||
}
|
}
|
||||||
s.fsmap = fsmap
|
s.rootDataset = in.RootDataset
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
@ -55,6 +51,7 @@ func (j *Sink) Run(ctx context.Context) {
|
|||||||
log.WithError(err).Error("cannot listen")
|
log.WithError(err).Error("cannot listen")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
log.WithField("addr", l.Addr()).Debug("accepting connections")
|
log.WithField("addr", l.Addr()).Debug("accepting connections")
|
||||||
|
|
||||||
@ -64,10 +61,10 @@ outer:
|
|||||||
for {
|
for {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case res := <-accept(l):
|
case res := <-accept(ctx, l):
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
log.WithError(err).Info("accept error")
|
log.WithError(res.err).Info("accept error")
|
||||||
break outer
|
continue
|
||||||
}
|
}
|
||||||
connId++
|
connId++
|
||||||
connLog := log.
|
connLog := log.
|
||||||
@ -82,14 +79,28 @@ outer:
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
|
func (j *Sink) handleConnection(ctx context.Context, conn serve.AuthenticatedConn) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
log := GetLogger(ctx)
|
log := GetLogger(ctx)
|
||||||
log.WithField("addr", conn.RemoteAddr()).Info("handling connection")
|
log.
|
||||||
|
WithField("addr", conn.RemoteAddr()).
|
||||||
|
WithField("client_identity", conn.ClientIdentity()).
|
||||||
|
Info("handling connection")
|
||||||
defer log.Info("finished handling connection")
|
defer log.Info("finished handling connection")
|
||||||
|
|
||||||
|
clientRoot := path.Join(j.rootDataset, conn.ClientIdentity())
|
||||||
|
log.WithField("client_root", clientRoot).Debug("client root")
|
||||||
|
fsmap := filters.NewDatasetMapFilter(1, false)
|
||||||
|
if err := fsmap.Add("<", clientRoot); err != nil {
|
||||||
|
log.WithError(err).
|
||||||
|
WithField("client_identity", conn.ClientIdentity()).
|
||||||
|
Error("cannot build client filesystem map (client identity must be a valid ZFS FS name")
|
||||||
|
}
|
||||||
|
|
||||||
ctx = logging.WithSubsystemLoggers(ctx, log)
|
ctx = logging.WithSubsystemLoggers(ctx, log)
|
||||||
|
|
||||||
local, err := endpoint.NewReceiver(j.fsmap, filters.NewAnyFSVFilter())
|
local, err := endpoint.NewReceiver(fsmap, filters.NewAnyFSVFilter())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
|
log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
|
||||||
return
|
return
|
||||||
@ -102,14 +113,14 @@ func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type acceptResult struct {
|
type acceptResult struct {
|
||||||
conn net.Conn
|
conn serve.AuthenticatedConn
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func accept(listener net.Listener) <-chan acceptResult {
|
func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult {
|
||||||
c := make(chan acceptResult, 1)
|
c := make(chan acceptResult, 1)
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept(ctx)
|
||||||
c <- acceptResult{conn, err}
|
c <- acceptResult{conn, err}
|
||||||
}()
|
}()
|
||||||
return c
|
return c
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/zrepl/zrepl/tlsconf"
|
"github.com/zrepl/zrepl/tlsconf"
|
||||||
"os"
|
"os"
|
||||||
"github.com/zrepl/zrepl/daemon/snapper"
|
"github.com/zrepl/zrepl/daemon/snapper"
|
||||||
|
"github.com/zrepl/zrepl/daemon/serve"
|
||||||
)
|
)
|
||||||
|
|
||||||
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
|
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
|
||||||
@ -71,6 +72,7 @@ func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Contex
|
|||||||
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint"))
|
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint"))
|
||||||
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning"))
|
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning"))
|
||||||
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot"))
|
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot"))
|
||||||
|
ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve"))
|
||||||
return ctx
|
return ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,10 +6,69 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
||||||
"github.com/problame/go-streamrpc"
|
"github.com/problame/go-streamrpc"
|
||||||
|
"context"
|
||||||
|
"github.com/zrepl/zrepl/logger"
|
||||||
|
"github.com/zrepl/zrepl/zfs"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const contextKeyLog contextKey = 0
|
||||||
|
|
||||||
|
type Logger = logger.Logger
|
||||||
|
|
||||||
|
func WithLogger(ctx context.Context, log Logger) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyLog, log)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLogger(ctx context.Context) Logger {
|
||||||
|
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
||||||
|
return log
|
||||||
|
}
|
||||||
|
return logger.NewNullLogger()
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthenticatedConn interface {
|
||||||
|
net.Conn
|
||||||
|
// ClientIdentity must be a string that satisfies ValidateClientIdentity
|
||||||
|
ClientIdentity() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// A client identity must be a single component in a ZFS filesystem path
|
||||||
|
func ValidateClientIdentity(in string) (err error) {
|
||||||
|
path, err := zfs.NewDatasetPath(in)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if path.Length() != 1 {
|
||||||
|
return errors.New("client identity must be a single path comonent (not empty, no '/')")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type authConn struct {
|
||||||
|
net.Conn
|
||||||
|
clientIdentity string
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ AuthenticatedConn = authConn{}
|
||||||
|
|
||||||
|
func (c authConn) ClientIdentity() string {
|
||||||
|
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return c.clientIdentity
|
||||||
|
}
|
||||||
|
|
||||||
|
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
|
||||||
|
type AuthenticatedListener interface {
|
||||||
|
Addr() (net.Addr)
|
||||||
|
Accept(ctx context.Context) (AuthenticatedConn, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
type ListenerFactory interface {
|
type ListenerFactory interface {
|
||||||
Listen() (net.Listener, error)
|
Listen() (AuthenticatedListener, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
||||||
@ -25,7 +84,7 @@ func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf
|
|||||||
lf, lfError = TLSListenerFactoryFromConfig(g, v)
|
lf, lfError = TLSListenerFactoryFromConfig(g, v)
|
||||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||||
case *config.StdinserverServer:
|
case *config.StdinserverServer:
|
||||||
lf, lfError = StdinserverListenerFactoryFromConfig(g, v)
|
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
|
||||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||||
default:
|
default:
|
||||||
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
|
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
|
||||||
|
@ -8,54 +8,133 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"path"
|
"path"
|
||||||
"time"
|
"time"
|
||||||
|
"context"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type StdinserverListenerFactory struct {
|
type StdinserverListenerFactory struct {
|
||||||
ClientIdentity string
|
ClientIdentities []string
|
||||||
sockpath string
|
Sockdir string
|
||||||
}
|
}
|
||||||
|
|
||||||
func StdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *StdinserverListenerFactory, err error) {
|
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
|
||||||
|
|
||||||
f = &StdinserverListenerFactory{
|
for _, ci := range in.ClientIdentities {
|
||||||
ClientIdentity: in.ClientIdentity,
|
if err := ValidateClientIdentity(ci); err != nil {
|
||||||
|
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
f.sockpath = path.Join(g.Serve.StdinServer.SockDir, f.ClientIdentity)
|
f = &multiStdinserverListenerFactory{
|
||||||
|
ClientIdentities: in.ClientIdentities,
|
||||||
|
Sockdir: g.Serve.StdinServer.SockDir,
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *StdinserverListenerFactory) Listen() (net.Listener, error) {
|
type multiStdinserverListenerFactory struct {
|
||||||
|
ClientIdentities []string
|
||||||
if err := nethelpers.PreparePrivateSockpath(f.sockpath); err != nil {
|
Sockdir string
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
l, err := netssh.Listen(f.sockpath)
|
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||||
|
return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiStdinserverAcceptRes struct {
|
||||||
|
conn AuthenticatedConn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type MultiStdinserverListener struct {
|
||||||
|
listeners []*stdinserverListener
|
||||||
|
accepts chan multiStdinserverAcceptRes
|
||||||
|
closed int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// client identities must be validated
|
||||||
|
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
|
||||||
|
listeners := make([]*stdinserverListener, 0, len(cis))
|
||||||
|
var err error
|
||||||
|
for _, ci := range cis {
|
||||||
|
sockpath := path.Join(sockdir, ci)
|
||||||
|
l := &stdinserverListener{clientIdentity: ci}
|
||||||
|
if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if l.l, err = netssh.Listen(sockpath); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
listeners = append(listeners, l)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
for _, l := range listeners {
|
||||||
|
l.Close() // FIXME error reporting?
|
||||||
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return StdinserverListener{l}, nil
|
return &MultiStdinserverListener{listeners: listeners}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type StdinserverListener struct {
|
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){
|
||||||
l *netssh.Listener
|
|
||||||
|
if m.accepts == nil {
|
||||||
|
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
|
||||||
|
for i := range m.listeners {
|
||||||
|
go func(i int) {
|
||||||
|
for atomic.LoadInt32(&m.closed) == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "accepting\n")
|
||||||
|
conn, err := m.listeners[i].Accept(context.TODO())
|
||||||
|
fmt.Fprintf(os.Stderr, "incoming\n")
|
||||||
|
m.accepts <- multiStdinserverAcceptRes{conn, err}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l StdinserverListener) Addr() net.Addr {
|
res := <- m.accepts
|
||||||
|
return res.conn, res.err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MultiStdinserverListener) Addr() (net.Addr) {
|
||||||
return netsshAddr{}
|
return netsshAddr{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l StdinserverListener) Accept() (net.Conn, error) {
|
func (m *MultiStdinserverListener) Close() error {
|
||||||
|
atomic.StoreInt32(&m.closed, 1)
|
||||||
|
var oneErr error
|
||||||
|
for _, l := range m.listeners {
|
||||||
|
if err := l.Close(); err != nil && oneErr == nil {
|
||||||
|
oneErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return oneErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// a single stdinserverListener (part of multiStinserverListener)
|
||||||
|
type stdinserverListener struct {
|
||||||
|
l *netssh.Listener
|
||||||
|
clientIdentity string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l stdinserverListener) Addr() net.Addr {
|
||||||
|
return netsshAddr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||||
c, err := l.l.Accept()
|
c, err := l.l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return netsshConnToNetConnAdatper{c}, nil
|
return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l StdinserverListener) Close() (err error) {
|
func (l stdinserverListener) Close() (err error) {
|
||||||
return l.l.Close()
|
return l.l.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,12 +145,16 @@ func (netsshAddr) String() string { return "???" }
|
|||||||
|
|
||||||
type netsshConnToNetConnAdatper struct {
|
type netsshConnToNetConnAdatper struct {
|
||||||
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
|
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
|
||||||
|
clientIdentity string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
|
||||||
|
|
||||||
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
|
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
|
||||||
|
|
||||||
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
|
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
|
||||||
|
|
||||||
|
// FIXME log warning once!
|
||||||
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
|
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
|
||||||
|
|
||||||
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
|
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
@ -3,19 +3,89 @@ package serve
|
|||||||
import (
|
import (
|
||||||
"github.com/zrepl/zrepl/config"
|
"github.com/zrepl/zrepl/config"
|
||||||
"net"
|
"net"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TCPListenerFactory struct {
|
type TCPListenerFactory struct {
|
||||||
Address string
|
address *net.TCPAddr
|
||||||
|
clientMap *ipMap
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipMapEntry struct {
|
||||||
|
ip net.IP
|
||||||
|
ident string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipMap struct {
|
||||||
|
entries []ipMapEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
func ipMapFromConfig(clients map[string]string) (*ipMap, error) {
|
||||||
|
entries := make([]ipMapEntry, 0, len(clients))
|
||||||
|
for clientIPString, clientIdent := range clients {
|
||||||
|
clientIP := net.ParseIP(clientIPString)
|
||||||
|
if clientIP == nil {
|
||||||
|
return nil, errors.Errorf("cannot parse client IP %q", clientIPString)
|
||||||
|
}
|
||||||
|
if err := ValidateClientIdentity(clientIdent); err != nil {
|
||||||
|
return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString)
|
||||||
|
}
|
||||||
|
entries = append(entries, ipMapEntry{clientIP, clientIdent})
|
||||||
|
}
|
||||||
|
return &ipMap{entries: entries}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ipMap) Get(ip net.IP) (string, error) {
|
||||||
|
for _, e := range m.entries {
|
||||||
|
if e.ip.Equal(ip) {
|
||||||
|
return e.ident, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", errors.Errorf("no identity mapping for client IP %s", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) {
|
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) {
|
||||||
|
addr, err := net.ResolveTCPAddr("tcp", in.Listen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "cannot parse listen address")
|
||||||
|
}
|
||||||
|
clientMap, err := ipMapFromConfig(in.Clients)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "cannot parse client IP map")
|
||||||
|
}
|
||||||
lf := &TCPListenerFactory{
|
lf := &TCPListenerFactory{
|
||||||
Address: in.Listen,
|
address: addr,
|
||||||
|
clientMap: clientMap,
|
||||||
}
|
}
|
||||||
return lf, nil
|
return lf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *TCPListenerFactory) Listen() (net.Listener, error) {
|
func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||||
return net.Listen("tcp", f.Address)
|
l, err := net.ListenTCP("tcp", f.address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return &TCPAuthListener{l, f.clientMap}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPAuthListener struct {
|
||||||
|
*net.TCPListener
|
||||||
|
clientMap *ipMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||||
|
nc, err := f.TCPListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
clientIP := nc.RemoteAddr().(*net.TCPAddr).IP
|
||||||
|
clientIdent, err := f.clientMap.Get(clientIP)
|
||||||
|
if err != nil {
|
||||||
|
getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
|
||||||
|
nc.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authConn{nc, clientIdent}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -8,13 +8,13 @@ import (
|
|||||||
"github.com/zrepl/zrepl/tlsconf"
|
"github.com/zrepl/zrepl/tlsconf"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TLSListenerFactory struct {
|
type TLSListenerFactory struct {
|
||||||
address string
|
address string
|
||||||
clientCA *x509.CertPool
|
clientCA *x509.CertPool
|
||||||
serverCert tls.Certificate
|
serverCert tls.Certificate
|
||||||
clientCommonName string
|
|
||||||
handshakeTimeout time.Duration
|
handshakeTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -23,12 +23,10 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
|
|||||||
address: in.Listen,
|
address: in.Listen,
|
||||||
}
|
}
|
||||||
|
|
||||||
if in.Ca == "" || in.Cert == "" || in.Key == "" || in.ClientCN == "" {
|
if in.Ca == "" || in.Cert == "" || in.Key == "" {
|
||||||
return nil, errors.New("fields 'ca', 'cert', 'key' and 'client_cn' must be specified")
|
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
lf.clientCommonName = in.ClientCN
|
|
||||||
|
|
||||||
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
|
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "cannot parse ca file")
|
return nil, errors.Wrap(err, "cannot parse ca file")
|
||||||
@ -42,11 +40,25 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
|
|||||||
return lf, nil
|
return lf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *TLSListenerFactory) Listen() (net.Listener, error) {
|
func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||||
l, err := net.Listen("tcp", f.address)
|
l, err := net.Listen("tcp", f.address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.clientCommonName, f.handshakeTimeout)
|
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout)
|
||||||
return tl, nil
|
return tlsAuthListener{tl}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type tlsAuthListener struct {
|
||||||
|
*tlsconf.ClientAuthListener
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||||
|
c, cn, err := l.ClientAuthListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authConn{c, cn}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
13
main.go
13
main.go
@ -57,6 +57,18 @@ var statusCmd = &cobra.Command{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var stdinserverCmd = &cobra.Command{
|
||||||
|
Use: "stdinserver CLIENT_IDENTITY",
|
||||||
|
Short: "start in stdinserver mode (from authorized_keys file)",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
conf, err := config.ParseConfig(rootArgs.configFile)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return client.RunStdinserver(conf, args)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
var rootArgs struct {
|
var rootArgs struct {
|
||||||
configFile string
|
configFile string
|
||||||
}
|
}
|
||||||
@ -67,6 +79,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(daemonCmd)
|
rootCmd.AddCommand(daemonCmd)
|
||||||
rootCmd.AddCommand(wakeupCmd)
|
rootCmd.AddCommand(wakeupCmd)
|
||||||
rootCmd.AddCommand(statusCmd)
|
rootCmd.AddCommand(statusCmd)
|
||||||
|
rootCmd.AddCommand(stdinserverCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -4,7 +4,6 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
@ -24,13 +23,12 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
|
|||||||
|
|
||||||
type ClientAuthListener struct {
|
type ClientAuthListener struct {
|
||||||
l net.Listener
|
l net.Listener
|
||||||
clientCommonName string
|
|
||||||
handshakeTimeout time.Duration
|
handshakeTimeout time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientAuthListener(
|
func NewClientAuthListener(
|
||||||
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
|
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
|
||||||
clientCommonName string, handshakeTimeout time.Duration) *ClientAuthListener {
|
handshakeTimeout time.Duration) *ClientAuthListener {
|
||||||
|
|
||||||
if ca == nil {
|
if ca == nil {
|
||||||
panic(ca)
|
panic(ca)
|
||||||
@ -38,9 +36,6 @@ func NewClientAuthListener(
|
|||||||
if serverCert.Certificate == nil || serverCert.PrivateKey == nil {
|
if serverCert.Certificate == nil || serverCert.PrivateKey == nil {
|
||||||
panic(serverCert)
|
panic(serverCert)
|
||||||
}
|
}
|
||||||
if clientCommonName == "" {
|
|
||||||
panic(clientCommonName)
|
|
||||||
}
|
|
||||||
|
|
||||||
tlsConf := tls.Config{
|
tlsConf := tls.Config{
|
||||||
Certificates: []tls.Certificate{serverCert},
|
Certificates: []tls.Certificate{serverCert},
|
||||||
@ -51,19 +46,18 @@ func NewClientAuthListener(
|
|||||||
l = tls.NewListener(l, &tlsConf)
|
l = tls.NewListener(l, &tlsConf)
|
||||||
return &ClientAuthListener{
|
return &ClientAuthListener{
|
||||||
l,
|
l,
|
||||||
clientCommonName,
|
|
||||||
handshakeTimeout,
|
handshakeTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
|
func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
|
||||||
c, err = l.l.Accept()
|
c, err = l.l.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
tlsConn, ok := c.(*tls.Conn)
|
tlsConn, ok := c.(*tls.Conn)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c, err
|
return c, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -83,14 +77,10 @@ func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
|
|||||||
goto CloseAndErr
|
goto CloseAndErr
|
||||||
}
|
}
|
||||||
cn = peerCerts[0].Subject.CommonName
|
cn = peerCerts[0].Subject.CommonName
|
||||||
if cn != l.clientCommonName {
|
return c, cn, nil
|
||||||
err = fmt.Errorf("client cert common name does not match client_identity: %q != %q", cn, l.clientCommonName)
|
|
||||||
goto CloseAndErr
|
|
||||||
}
|
|
||||||
return c, nil
|
|
||||||
CloseAndErr:
|
CloseAndErr:
|
||||||
c.Close()
|
c.Close()
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *ClientAuthListener) Addr() net.Addr {
|
func (l *ClientAuthListener) Addr() net.Addr {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user