mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-21 16:03:32 +01:00
Multi-client servers + bring back stdinserver support
This commit is contained in:
parent
e161347e47
commit
308e5e35fb
41
client/stdinserver.go
Normal file
41
client/stdinserver.go
Normal file
@ -0,0 +1,41 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"context"
|
||||
"github.com/problame/go-netssh"
|
||||
"log"
|
||||
"path"
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"errors"
|
||||
)
|
||||
|
||||
|
||||
func RunStdinserver(config *config.Config, args []string) error {
|
||||
|
||||
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
|
||||
defer os.Exit(1)
|
||||
|
||||
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
|
||||
|
||||
if len(args) != 1 || args[0] == "" {
|
||||
err := errors.New("must specify client_identity as positional argument")
|
||||
return err
|
||||
}
|
||||
|
||||
identity := args[0]
|
||||
unixaddr := path.Join(config.Global.Serve.StdinServer.SockDir, identity)
|
||||
|
||||
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
|
||||
|
||||
ctx := netssh.ContextWithLog(context.TODO(), log)
|
||||
|
||||
err := netssh.Proxy(ctx, unixaddr)
|
||||
if err == nil {
|
||||
log.Print("proxying finished successfully, exiting with status 0")
|
||||
os.Exit(0)
|
||||
}
|
||||
log.Printf("error proxying: %s", err)
|
||||
return nil
|
||||
}
|
@ -165,7 +165,7 @@ type SSHStdinserverConnect struct {
|
||||
IdentityFile string `yaml:"identity_file"`
|
||||
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
|
||||
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
|
||||
Options []string `yaml:"options"`
|
||||
Options []string `yaml:"options,optional"`
|
||||
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
|
||||
}
|
||||
|
||||
@ -190,13 +190,13 @@ type TLSServe struct {
|
||||
Ca string `yaml:"ca"`
|
||||
Cert string `yaml:"cert"`
|
||||
Key string `yaml:"key"`
|
||||
ClientCN string `yaml:"client_cn"`
|
||||
ClientCNs []string `yaml:"client_cns"`
|
||||
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"`
|
||||
}
|
||||
|
||||
type StdinserverServer struct {
|
||||
ServeCommon `yaml:",inline"`
|
||||
ClientIdentity string `yaml:"client_identity"`
|
||||
ClientIdentities []string `yaml:"client_identities"`
|
||||
}
|
||||
|
||||
type PruningEnum struct {
|
||||
|
@ -8,7 +8,9 @@ jobs:
|
||||
ca: "ca.pem"
|
||||
cert: "cert.pem"
|
||||
key: "key.pem"
|
||||
client_cn: "laptop1"
|
||||
client_cns:
|
||||
- "laptop1"
|
||||
- "homeserver"
|
||||
global:
|
||||
logging:
|
||||
- type: "tcp"
|
||||
|
@ -3,7 +3,9 @@ jobs:
|
||||
type: source
|
||||
serve:
|
||||
type: stdinserver
|
||||
client_identity: "client1"
|
||||
client_identities:
|
||||
- "client1"
|
||||
- "client2"
|
||||
filesystems: {
|
||||
"<": true,
|
||||
"secret": false
|
||||
|
@ -9,31 +9,27 @@ import (
|
||||
"github.com/zrepl/zrepl/daemon/logging"
|
||||
"github.com/zrepl/zrepl/daemon/serve"
|
||||
"github.com/zrepl/zrepl/endpoint"
|
||||
"net"
|
||||
"path"
|
||||
)
|
||||
|
||||
type Sink struct {
|
||||
name string
|
||||
l serve.ListenerFactory
|
||||
rpcConf *streamrpc.ConnConfig
|
||||
fsmap endpoint.FSMap
|
||||
fsmapInv endpoint.FSFilter
|
||||
rootDataset string
|
||||
}
|
||||
|
||||
func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) {
|
||||
|
||||
// FIXME multi client support
|
||||
|
||||
s = &Sink{name: in.Name}
|
||||
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil {
|
||||
return nil, errors.Wrap(err, "cannot build server")
|
||||
}
|
||||
|
||||
fsmap := filters.NewDatasetMapFilter(1, false) // FIXME multi-client support
|
||||
if err := fsmap.Add("<", in.RootDataset); err != nil {
|
||||
return nil, errors.Wrap(err, "unexpected error: cannot build filesystem mapping")
|
||||
if in.RootDataset == "" {
|
||||
return nil, errors.Wrap(err, "must specify root dataset")
|
||||
}
|
||||
s.fsmap = fsmap
|
||||
s.rootDataset = in.RootDataset
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@ -55,6 +51,7 @@ func (j *Sink) Run(ctx context.Context) {
|
||||
log.WithError(err).Error("cannot listen")
|
||||
return
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
log.WithField("addr", l.Addr()).Debug("accepting connections")
|
||||
|
||||
@ -64,10 +61,10 @@ outer:
|
||||
for {
|
||||
|
||||
select {
|
||||
case res := <-accept(l):
|
||||
case res := <-accept(ctx, l):
|
||||
if res.err != nil {
|
||||
log.WithError(err).Info("accept error")
|
||||
break outer
|
||||
log.WithError(res.err).Info("accept error")
|
||||
continue
|
||||
}
|
||||
connId++
|
||||
connLog := log.
|
||||
@ -82,14 +79,28 @@ outer:
|
||||
|
||||
}
|
||||
|
||||
func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
func (j *Sink) handleConnection(ctx context.Context, conn serve.AuthenticatedConn) {
|
||||
defer conn.Close()
|
||||
|
||||
log := GetLogger(ctx)
|
||||
log.WithField("addr", conn.RemoteAddr()).Info("handling connection")
|
||||
log.
|
||||
WithField("addr", conn.RemoteAddr()).
|
||||
WithField("client_identity", conn.ClientIdentity()).
|
||||
Info("handling connection")
|
||||
defer log.Info("finished handling connection")
|
||||
|
||||
clientRoot := path.Join(j.rootDataset, conn.ClientIdentity())
|
||||
log.WithField("client_root", clientRoot).Debug("client root")
|
||||
fsmap := filters.NewDatasetMapFilter(1, false)
|
||||
if err := fsmap.Add("<", clientRoot); err != nil {
|
||||
log.WithError(err).
|
||||
WithField("client_identity", conn.ClientIdentity()).
|
||||
Error("cannot build client filesystem map (client identity must be a valid ZFS FS name")
|
||||
}
|
||||
|
||||
ctx = logging.WithSubsystemLoggers(ctx, log)
|
||||
|
||||
local, err := endpoint.NewReceiver(j.fsmap, filters.NewAnyFSVFilter())
|
||||
local, err := endpoint.NewReceiver(fsmap, filters.NewAnyFSVFilter())
|
||||
if err != nil {
|
||||
log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
|
||||
return
|
||||
@ -102,14 +113,14 @@ func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
|
||||
type acceptResult struct {
|
||||
conn net.Conn
|
||||
conn serve.AuthenticatedConn
|
||||
err error
|
||||
}
|
||||
|
||||
func accept(listener net.Listener) <-chan acceptResult {
|
||||
func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult {
|
||||
c := make(chan acceptResult, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
conn, err := listener.Accept(ctx)
|
||||
c <- acceptResult{conn, err}
|
||||
}()
|
||||
return c
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/zrepl/zrepl/tlsconf"
|
||||
"os"
|
||||
"github.com/zrepl/zrepl/daemon/snapper"
|
||||
"github.com/zrepl/zrepl/daemon/serve"
|
||||
)
|
||||
|
||||
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
|
||||
@ -71,6 +72,7 @@ func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Contex
|
||||
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint"))
|
||||
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning"))
|
||||
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot"))
|
||||
ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve"))
|
||||
return ctx
|
||||
}
|
||||
|
||||
|
@ -6,10 +6,69 @@ import (
|
||||
"net"
|
||||
"github.com/zrepl/zrepl/daemon/streamrpcconfig"
|
||||
"github.com/problame/go-streamrpc"
|
||||
"context"
|
||||
"github.com/zrepl/zrepl/logger"
|
||||
"github.com/zrepl/zrepl/zfs"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const contextKeyLog contextKey = 0
|
||||
|
||||
type Logger = logger.Logger
|
||||
|
||||
func WithLogger(ctx context.Context, log Logger) context.Context {
|
||||
return context.WithValue(ctx, contextKeyLog, log)
|
||||
}
|
||||
|
||||
func getLogger(ctx context.Context) Logger {
|
||||
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
|
||||
return log
|
||||
}
|
||||
return logger.NewNullLogger()
|
||||
}
|
||||
|
||||
type AuthenticatedConn interface {
|
||||
net.Conn
|
||||
// ClientIdentity must be a string that satisfies ValidateClientIdentity
|
||||
ClientIdentity() string
|
||||
}
|
||||
|
||||
// A client identity must be a single component in a ZFS filesystem path
|
||||
func ValidateClientIdentity(in string) (err error) {
|
||||
path, err := zfs.NewDatasetPath(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if path.Length() != 1 {
|
||||
return errors.New("client identity must be a single path comonent (not empty, no '/')")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type authConn struct {
|
||||
net.Conn
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
var _ AuthenticatedConn = authConn{}
|
||||
|
||||
func (c authConn) ClientIdentity() string {
|
||||
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return c.clientIdentity
|
||||
}
|
||||
|
||||
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
|
||||
type AuthenticatedListener interface {
|
||||
Addr() (net.Addr)
|
||||
Accept(ctx context.Context) (AuthenticatedConn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type ListenerFactory interface {
|
||||
Listen() (net.Listener, error)
|
||||
Listen() (AuthenticatedListener, error)
|
||||
}
|
||||
|
||||
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
|
||||
@ -25,7 +84,7 @@ func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf
|
||||
lf, lfError = TLSListenerFactoryFromConfig(g, v)
|
||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
case *config.StdinserverServer:
|
||||
lf, lfError = StdinserverListenerFactoryFromConfig(g, v)
|
||||
lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
|
||||
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
|
||||
default:
|
||||
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)
|
||||
|
@ -8,54 +8,133 @@ import (
|
||||
"net"
|
||||
"path"
|
||||
"time"
|
||||
"context"
|
||||
"github.com/pkg/errors"
|
||||
"sync/atomic"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
type StdinserverListenerFactory struct {
|
||||
ClientIdentity string
|
||||
sockpath string
|
||||
ClientIdentities []string
|
||||
Sockdir string
|
||||
}
|
||||
|
||||
func StdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *StdinserverListenerFactory, err error) {
|
||||
func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
|
||||
|
||||
f = &StdinserverListenerFactory{
|
||||
ClientIdentity: in.ClientIdentity,
|
||||
for _, ci := range in.ClientIdentities {
|
||||
if err := ValidateClientIdentity(ci); err != nil {
|
||||
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
|
||||
}
|
||||
}
|
||||
|
||||
f.sockpath = path.Join(g.Serve.StdinServer.SockDir, f.ClientIdentity)
|
||||
f = &multiStdinserverListenerFactory{
|
||||
ClientIdentities: in.ClientIdentities,
|
||||
Sockdir: g.Serve.StdinServer.SockDir,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (f *StdinserverListenerFactory) Listen() (net.Listener, error) {
|
||||
type multiStdinserverListenerFactory struct {
|
||||
ClientIdentities []string
|
||||
Sockdir string
|
||||
}
|
||||
|
||||
if err := nethelpers.PreparePrivateSockpath(f.sockpath); err != nil {
|
||||
return nil, err
|
||||
func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
|
||||
}
|
||||
|
||||
type multiStdinserverAcceptRes struct {
|
||||
conn AuthenticatedConn
|
||||
err error
|
||||
}
|
||||
|
||||
type MultiStdinserverListener struct {
|
||||
listeners []*stdinserverListener
|
||||
accepts chan multiStdinserverAcceptRes
|
||||
closed int32
|
||||
}
|
||||
|
||||
// client identities must be validated
|
||||
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
|
||||
listeners := make([]*stdinserverListener, 0, len(cis))
|
||||
var err error
|
||||
for _, ci := range cis {
|
||||
sockpath := path.Join(sockdir, ci)
|
||||
l := &stdinserverListener{clientIdentity: ci}
|
||||
if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
|
||||
break
|
||||
}
|
||||
if l.l, err = netssh.Listen(sockpath); err != nil {
|
||||
break
|
||||
}
|
||||
listeners = append(listeners, l)
|
||||
}
|
||||
|
||||
l, err := netssh.Listen(f.sockpath)
|
||||
if err != nil {
|
||||
for _, l := range listeners {
|
||||
l.Close() // FIXME error reporting?
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return StdinserverListener{l}, nil
|
||||
return &MultiStdinserverListener{listeners: listeners}, nil
|
||||
}
|
||||
|
||||
type StdinserverListener struct {
|
||||
l *netssh.Listener
|
||||
func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){
|
||||
|
||||
if m.accepts == nil {
|
||||
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
|
||||
for i := range m.listeners {
|
||||
go func(i int) {
|
||||
for atomic.LoadInt32(&m.closed) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "accepting\n")
|
||||
conn, err := m.listeners[i].Accept(context.TODO())
|
||||
fmt.Fprintf(os.Stderr, "incoming\n")
|
||||
m.accepts <- multiStdinserverAcceptRes{conn, err}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
}
|
||||
|
||||
res := <- m.accepts
|
||||
return res.conn, res.err
|
||||
|
||||
}
|
||||
|
||||
func (l StdinserverListener) Addr() net.Addr {
|
||||
func (m *MultiStdinserverListener) Addr() (net.Addr) {
|
||||
return netsshAddr{}
|
||||
}
|
||||
|
||||
func (l StdinserverListener) Accept() (net.Conn, error) {
|
||||
func (m *MultiStdinserverListener) Close() error {
|
||||
atomic.StoreInt32(&m.closed, 1)
|
||||
var oneErr error
|
||||
for _, l := range m.listeners {
|
||||
if err := l.Close(); err != nil && oneErr == nil {
|
||||
oneErr = err
|
||||
}
|
||||
}
|
||||
return oneErr
|
||||
}
|
||||
|
||||
// a single stdinserverListener (part of multiStinserverListener)
|
||||
type stdinserverListener struct {
|
||||
l *netssh.Listener
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Addr() net.Addr {
|
||||
return netsshAddr{}
|
||||
}
|
||||
|
||||
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
c, err := l.l.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return netsshConnToNetConnAdatper{c}, nil
|
||||
return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil
|
||||
}
|
||||
|
||||
func (l StdinserverListener) Close() (err error) {
|
||||
func (l stdinserverListener) Close() (err error) {
|
||||
return l.l.Close()
|
||||
}
|
||||
|
||||
@ -66,12 +145,16 @@ func (netsshAddr) String() string { return "???" }
|
||||
|
||||
type netsshConnToNetConnAdatper struct {
|
||||
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
|
||||
clientIdentity string
|
||||
}
|
||||
|
||||
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
|
||||
|
||||
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
|
||||
|
||||
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
|
||||
|
||||
// FIXME log warning once!
|
||||
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }
|
||||
|
@ -3,19 +3,89 @@ package serve
|
||||
import (
|
||||
"github.com/zrepl/zrepl/config"
|
||||
"net"
|
||||
"github.com/pkg/errors"
|
||||
"context"
|
||||
)
|
||||
|
||||
type TCPListenerFactory struct {
|
||||
Address string
|
||||
address *net.TCPAddr
|
||||
clientMap *ipMap
|
||||
}
|
||||
|
||||
type ipMapEntry struct {
|
||||
ip net.IP
|
||||
ident string
|
||||
}
|
||||
|
||||
type ipMap struct {
|
||||
entries []ipMapEntry
|
||||
}
|
||||
|
||||
func ipMapFromConfig(clients map[string]string) (*ipMap, error) {
|
||||
entries := make([]ipMapEntry, 0, len(clients))
|
||||
for clientIPString, clientIdent := range clients {
|
||||
clientIP := net.ParseIP(clientIPString)
|
||||
if clientIP == nil {
|
||||
return nil, errors.Errorf("cannot parse client IP %q", clientIPString)
|
||||
}
|
||||
if err := ValidateClientIdentity(clientIdent); err != nil {
|
||||
return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString)
|
||||
}
|
||||
entries = append(entries, ipMapEntry{clientIP, clientIdent})
|
||||
}
|
||||
return &ipMap{entries: entries}, nil
|
||||
}
|
||||
|
||||
func (m *ipMap) Get(ip net.IP) (string, error) {
|
||||
for _, e := range m.entries {
|
||||
if e.ip.Equal(ip) {
|
||||
return e.ident, nil
|
||||
}
|
||||
}
|
||||
return "", errors.Errorf("no identity mapping for client IP %s", ip)
|
||||
}
|
||||
|
||||
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", in.Listen)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse listen address")
|
||||
}
|
||||
clientMap, err := ipMapFromConfig(in.Clients)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse client IP map")
|
||||
}
|
||||
lf := &TCPListenerFactory{
|
||||
Address: in.Listen,
|
||||
address: addr,
|
||||
clientMap: clientMap,
|
||||
}
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
func (f *TCPListenerFactory) Listen() (net.Listener, error) {
|
||||
return net.Listen("tcp", f.Address)
|
||||
func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
l, err := net.ListenTCP("tcp", f.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TCPAuthListener{l, f.clientMap}, nil
|
||||
}
|
||||
|
||||
type TCPAuthListener struct {
|
||||
*net.TCPListener
|
||||
clientMap *ipMap
|
||||
}
|
||||
|
||||
func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
nc, err := f.TCPListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientIP := nc.RemoteAddr().(*net.TCPAddr).IP
|
||||
clientIdent, err := f.clientMap.Get(clientIP)
|
||||
if err != nil {
|
||||
getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
|
||||
nc.Close()
|
||||
return nil, err
|
||||
}
|
||||
return authConn{nc, clientIdent}, nil
|
||||
}
|
||||
|
||||
|
@ -8,13 +8,13 @@ import (
|
||||
"github.com/zrepl/zrepl/tlsconf"
|
||||
"net"
|
||||
"time"
|
||||
"context"
|
||||
)
|
||||
|
||||
type TLSListenerFactory struct {
|
||||
address string
|
||||
clientCA *x509.CertPool
|
||||
serverCert tls.Certificate
|
||||
clientCommonName string
|
||||
handshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
@ -23,12 +23,10 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
|
||||
address: in.Listen,
|
||||
}
|
||||
|
||||
if in.Ca == "" || in.Cert == "" || in.Key == "" || in.ClientCN == "" {
|
||||
return nil, errors.New("fields 'ca', 'cert', 'key' and 'client_cn' must be specified")
|
||||
if in.Ca == "" || in.Cert == "" || in.Key == "" {
|
||||
return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
|
||||
}
|
||||
|
||||
lf.clientCommonName = in.ClientCN
|
||||
|
||||
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "cannot parse ca file")
|
||||
@ -42,11 +40,25 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
|
||||
return lf, nil
|
||||
}
|
||||
|
||||
func (f *TLSListenerFactory) Listen() (net.Listener, error) {
|
||||
func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) {
|
||||
l, err := net.Listen("tcp", f.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.clientCommonName, f.handshakeTimeout)
|
||||
return tl, nil
|
||||
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout)
|
||||
return tlsAuthListener{tl}, nil
|
||||
}
|
||||
|
||||
type tlsAuthListener struct {
|
||||
*tlsconf.ClientAuthListener
|
||||
}
|
||||
|
||||
func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
|
||||
c, cn, err := l.ClientAuthListener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authConn{c, cn}, nil
|
||||
}
|
||||
|
||||
|
||||
|
13
main.go
13
main.go
@ -57,6 +57,18 @@ var statusCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
var stdinserverCmd = &cobra.Command{
|
||||
Use: "stdinserver CLIENT_IDENTITY",
|
||||
Short: "start in stdinserver mode (from authorized_keys file)",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
conf, err := config.ParseConfig(rootArgs.configFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.RunStdinserver(conf, args)
|
||||
},
|
||||
}
|
||||
|
||||
var rootArgs struct {
|
||||
configFile string
|
||||
}
|
||||
@ -67,6 +79,7 @@ func init() {
|
||||
rootCmd.AddCommand(daemonCmd)
|
||||
rootCmd.AddCommand(wakeupCmd)
|
||||
rootCmd.AddCommand(statusCmd)
|
||||
rootCmd.AddCommand(stdinserverCmd)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"time"
|
||||
@ -24,13 +23,12 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
|
||||
|
||||
type ClientAuthListener struct {
|
||||
l net.Listener
|
||||
clientCommonName string
|
||||
handshakeTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewClientAuthListener(
|
||||
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
|
||||
clientCommonName string, handshakeTimeout time.Duration) *ClientAuthListener {
|
||||
handshakeTimeout time.Duration) *ClientAuthListener {
|
||||
|
||||
if ca == nil {
|
||||
panic(ca)
|
||||
@ -38,9 +36,6 @@ func NewClientAuthListener(
|
||||
if serverCert.Certificate == nil || serverCert.PrivateKey == nil {
|
||||
panic(serverCert)
|
||||
}
|
||||
if clientCommonName == "" {
|
||||
panic(clientCommonName)
|
||||
}
|
||||
|
||||
tlsConf := tls.Config{
|
||||
Certificates: []tls.Certificate{serverCert},
|
||||
@ -51,19 +46,18 @@ func NewClientAuthListener(
|
||||
l = tls.NewListener(l, &tlsConf)
|
||||
return &ClientAuthListener{
|
||||
l,
|
||||
clientCommonName,
|
||||
handshakeTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
|
||||
func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
|
||||
c, err = l.l.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
tlsConn, ok := c.(*tls.Conn)
|
||||
if !ok {
|
||||
return c, err
|
||||
return c, "", err
|
||||
}
|
||||
|
||||
var (
|
||||
@ -83,14 +77,10 @@ func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
|
||||
goto CloseAndErr
|
||||
}
|
||||
cn = peerCerts[0].Subject.CommonName
|
||||
if cn != l.clientCommonName {
|
||||
err = fmt.Errorf("client cert common name does not match client_identity: %q != %q", cn, l.clientCommonName)
|
||||
goto CloseAndErr
|
||||
}
|
||||
return c, nil
|
||||
return c, cn, nil
|
||||
CloseAndErr:
|
||||
c.Close()
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
func (l *ClientAuthListener) Addr() net.Addr {
|
||||
|
Loading…
Reference in New Issue
Block a user