Multi-client servers + bring back stdinserver support

This commit is contained in:
Christian Schwarz 2018-09-04 16:41:54 -07:00
parent e161347e47
commit 308e5e35fb
12 changed files with 356 additions and 71 deletions

41
client/stdinserver.go Normal file
View File

@ -0,0 +1,41 @@
package client
import (
"os"
"context"
"github.com/problame/go-netssh"
"log"
"path"
"github.com/zrepl/zrepl/config"
"errors"
)
func RunStdinserver(config *config.Config, args []string) error {
// NOTE: the netssh proxying protocol requires exiting with non-zero status if anything goes wrong
defer os.Exit(1)
log := log.New(os.Stderr, "", log.LUTC|log.Ldate|log.Ltime)
if len(args) != 1 || args[0] == "" {
err := errors.New("must specify client_identity as positional argument")
return err
}
identity := args[0]
unixaddr := path.Join(config.Global.Serve.StdinServer.SockDir, identity)
log.Printf("proxying client identity '%s' to zrepl daemon '%s'", identity, unixaddr)
ctx := netssh.ContextWithLog(context.TODO(), log)
err := netssh.Proxy(ctx, unixaddr)
if err == nil {
log.Print("proxying finished successfully, exiting with status 0")
os.Exit(0)
}
log.Printf("error proxying: %s", err)
return nil
}

View File

@ -165,7 +165,7 @@ type SSHStdinserverConnect struct {
IdentityFile string `yaml:"identity_file"` IdentityFile string `yaml:"identity_file"`
TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused TransportOpenCommand []string `yaml:"transport_open_command,optional"` //TODO unused
SSHCommand string `yaml:"ssh_command,optional"` //TODO unused SSHCommand string `yaml:"ssh_command,optional"` //TODO unused
Options []string `yaml:"options"` Options []string `yaml:"options,optional"`
DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"` DialTimeout time.Duration `yaml:"dial_timeout,positive,default=10s"`
} }
@ -190,13 +190,13 @@ type TLSServe struct {
Ca string `yaml:"ca"` Ca string `yaml:"ca"`
Cert string `yaml:"cert"` Cert string `yaml:"cert"`
Key string `yaml:"key"` Key string `yaml:"key"`
ClientCN string `yaml:"client_cn"` ClientCNs []string `yaml:"client_cns"`
HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"` HandshakeTimeout time.Duration `yaml:"handshake_timeout,positive,default=10s"`
} }
type StdinserverServer struct { type StdinserverServer struct {
ServeCommon `yaml:",inline"` ServeCommon `yaml:",inline"`
ClientIdentity string `yaml:"client_identity"` ClientIdentities []string `yaml:"client_identities"`
} }
type PruningEnum struct { type PruningEnum struct {

View File

@ -8,7 +8,9 @@ jobs:
ca: "ca.pem" ca: "ca.pem"
cert: "cert.pem" cert: "cert.pem"
key: "key.pem" key: "key.pem"
client_cn: "laptop1" client_cns:
- "laptop1"
- "homeserver"
global: global:
logging: logging:
- type: "tcp" - type: "tcp"

View File

@ -3,7 +3,9 @@ jobs:
type: source type: source
serve: serve:
type: stdinserver type: stdinserver
client_identity: "client1" client_identities:
- "client1"
- "client2"
filesystems: { filesystems: {
"<": true, "<": true,
"secret": false "secret": false

View File

@ -9,31 +9,27 @@ import (
"github.com/zrepl/zrepl/daemon/logging" "github.com/zrepl/zrepl/daemon/logging"
"github.com/zrepl/zrepl/daemon/serve" "github.com/zrepl/zrepl/daemon/serve"
"github.com/zrepl/zrepl/endpoint" "github.com/zrepl/zrepl/endpoint"
"net" "path"
) )
type Sink struct { type Sink struct {
name string name string
l serve.ListenerFactory l serve.ListenerFactory
rpcConf *streamrpc.ConnConfig rpcConf *streamrpc.ConnConfig
fsmap endpoint.FSMap rootDataset string
fsmapInv endpoint.FSFilter
} }
func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) { func SinkFromConfig(g *config.Global, in *config.SinkJob) (s *Sink, err error) {
// FIXME multi client support
s = &Sink{name: in.Name} s = &Sink{name: in.Name}
if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil { if s.l, s.rpcConf, err = serve.FromConfig(g, in.Serve); err != nil {
return nil, errors.Wrap(err, "cannot build server") return nil, errors.Wrap(err, "cannot build server")
} }
fsmap := filters.NewDatasetMapFilter(1, false) // FIXME multi-client support if in.RootDataset == "" {
if err := fsmap.Add("<", in.RootDataset); err != nil { return nil, errors.Wrap(err, "must specify root dataset")
return nil, errors.Wrap(err, "unexpected error: cannot build filesystem mapping")
} }
s.fsmap = fsmap s.rootDataset = in.RootDataset
return s, nil return s, nil
} }
@ -55,6 +51,7 @@ func (j *Sink) Run(ctx context.Context) {
log.WithError(err).Error("cannot listen") log.WithError(err).Error("cannot listen")
return return
} }
defer l.Close()
log.WithField("addr", l.Addr()).Debug("accepting connections") log.WithField("addr", l.Addr()).Debug("accepting connections")
@ -64,10 +61,10 @@ outer:
for { for {
select { select {
case res := <-accept(l): case res := <-accept(ctx, l):
if res.err != nil { if res.err != nil {
log.WithError(err).Info("accept error") log.WithError(res.err).Info("accept error")
break outer continue
} }
connId++ connId++
connLog := log. connLog := log.
@ -82,14 +79,28 @@ outer:
} }
func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) { func (j *Sink) handleConnection(ctx context.Context, conn serve.AuthenticatedConn) {
defer conn.Close()
log := GetLogger(ctx) log := GetLogger(ctx)
log.WithField("addr", conn.RemoteAddr()).Info("handling connection") log.
WithField("addr", conn.RemoteAddr()).
WithField("client_identity", conn.ClientIdentity()).
Info("handling connection")
defer log.Info("finished handling connection") defer log.Info("finished handling connection")
clientRoot := path.Join(j.rootDataset, conn.ClientIdentity())
log.WithField("client_root", clientRoot).Debug("client root")
fsmap := filters.NewDatasetMapFilter(1, false)
if err := fsmap.Add("<", clientRoot); err != nil {
log.WithError(err).
WithField("client_identity", conn.ClientIdentity()).
Error("cannot build client filesystem map (client identity must be a valid ZFS FS name")
}
ctx = logging.WithSubsystemLoggers(ctx, log) ctx = logging.WithSubsystemLoggers(ctx, log)
local, err := endpoint.NewReceiver(j.fsmap, filters.NewAnyFSVFilter()) local, err := endpoint.NewReceiver(fsmap, filters.NewAnyFSVFilter())
if err != nil { if err != nil {
log.WithError(err).Error("unexpected error: cannot convert mapping to filter") log.WithError(err).Error("unexpected error: cannot convert mapping to filter")
return return
@ -102,14 +113,14 @@ func (j *Sink) handleConnection(ctx context.Context, conn net.Conn) {
} }
type acceptResult struct { type acceptResult struct {
conn net.Conn conn serve.AuthenticatedConn
err error err error
} }
func accept(listener net.Listener) <-chan acceptResult { func accept(ctx context.Context, listener serve.AuthenticatedListener) <-chan acceptResult {
c := make(chan acceptResult, 1) c := make(chan acceptResult, 1)
go func() { go func() {
conn, err := listener.Accept() conn, err := listener.Accept(ctx)
c <- acceptResult{conn, err} c <- acceptResult{conn, err}
}() }()
return c return c

View File

@ -15,6 +15,7 @@ import (
"github.com/zrepl/zrepl/tlsconf" "github.com/zrepl/zrepl/tlsconf"
"os" "os"
"github.com/zrepl/zrepl/daemon/snapper" "github.com/zrepl/zrepl/daemon/snapper"
"github.com/zrepl/zrepl/daemon/serve"
) )
func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) { func OutletsFromConfig(in config.LoggingOutletEnumList) (*logger.Outlets, error) {
@ -71,6 +72,7 @@ func WithSubsystemLoggers(ctx context.Context, log logger.Logger) context.Contex
ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint")) ctx = endpoint.WithLogger(ctx, log.WithField(SubsysField, "endpoint"))
ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning")) ctx = pruner.WithLogger(ctx, log.WithField(SubsysField, "pruning"))
ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot")) ctx = snapper.WithLogger(ctx, log.WithField(SubsysField, "snapshot"))
ctx = serve.WithLogger(ctx, log.WithField(SubsysField, "serve"))
return ctx return ctx
} }

View File

@ -6,10 +6,69 @@ import (
"net" "net"
"github.com/zrepl/zrepl/daemon/streamrpcconfig" "github.com/zrepl/zrepl/daemon/streamrpcconfig"
"github.com/problame/go-streamrpc" "github.com/problame/go-streamrpc"
"context"
"github.com/zrepl/zrepl/logger"
"github.com/zrepl/zrepl/zfs"
) )
type contextKey int
const contextKeyLog contextKey = 0
type Logger = logger.Logger
func WithLogger(ctx context.Context, log Logger) context.Context {
return context.WithValue(ctx, contextKeyLog, log)
}
func getLogger(ctx context.Context) Logger {
if log, ok := ctx.Value(contextKeyLog).(Logger); ok {
return log
}
return logger.NewNullLogger()
}
type AuthenticatedConn interface {
net.Conn
// ClientIdentity must be a string that satisfies ValidateClientIdentity
ClientIdentity() string
}
// A client identity must be a single component in a ZFS filesystem path
func ValidateClientIdentity(in string) (err error) {
path, err := zfs.NewDatasetPath(in)
if err != nil {
return err
}
if path.Length() != 1 {
return errors.New("client identity must be a single path comonent (not empty, no '/')")
}
return nil
}
type authConn struct {
net.Conn
clientIdentity string
}
var _ AuthenticatedConn = authConn{}
func (c authConn) ClientIdentity() string {
if err := ValidateClientIdentity(c.clientIdentity); err != nil {
panic(err)
}
return c.clientIdentity
}
// like net.Listener, but with an AuthenticatedConn instead of net.Conn
type AuthenticatedListener interface {
Addr() (net.Addr)
Accept(ctx context.Context) (AuthenticatedConn, error)
Close() error
}
type ListenerFactory interface { type ListenerFactory interface {
Listen() (net.Listener, error) Listen() (AuthenticatedListener, error)
} }
func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) { func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf *streamrpc.ConnConfig, _ error) {
@ -25,7 +84,7 @@ func FromConfig(g *config.Global, in config.ServeEnum) (lf ListenerFactory, conf
lf, lfError = TLSListenerFactoryFromConfig(g, v) lf, lfError = TLSListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
case *config.StdinserverServer: case *config.StdinserverServer:
lf, lfError = StdinserverListenerFactoryFromConfig(g, v) lf, lfError = MultiStdinserverListenerFactoryFromConfig(g, v)
conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC) conf, rpcErr = streamrpcconfig.FromDaemonConfig(g, v.RPC)
default: default:
return nil, nil, errors.Errorf("internal error: unknown serve type %T", v) return nil, nil, errors.Errorf("internal error: unknown serve type %T", v)

View File

@ -8,54 +8,133 @@ import (
"net" "net"
"path" "path"
"time" "time"
"context"
"github.com/pkg/errors"
"sync/atomic"
"fmt"
"os"
) )
type StdinserverListenerFactory struct { type StdinserverListenerFactory struct {
ClientIdentity string ClientIdentities []string
sockpath string Sockdir string
} }
func StdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *StdinserverListenerFactory, err error) { func MultiStdinserverListenerFactoryFromConfig(g *config.Global, in *config.StdinserverServer) (f *multiStdinserverListenerFactory, err error) {
f = &StdinserverListenerFactory{ for _, ci := range in.ClientIdentities {
ClientIdentity: in.ClientIdentity, if err := ValidateClientIdentity(ci); err != nil {
return nil, errors.Wrapf(err, "invalid client identity %q", ci)
}
} }
f.sockpath = path.Join(g.Serve.StdinServer.SockDir, f.ClientIdentity) f = &multiStdinserverListenerFactory{
ClientIdentities: in.ClientIdentities,
Sockdir: g.Serve.StdinServer.SockDir,
}
return return
} }
func (f *StdinserverListenerFactory) Listen() (net.Listener, error) { type multiStdinserverListenerFactory struct {
ClientIdentities []string
Sockdir string
}
if err := nethelpers.PreparePrivateSockpath(f.sockpath); err != nil { func (f *multiStdinserverListenerFactory) Listen() (AuthenticatedListener, error) {
return nil, err return multiStdinserverListenerFromClientIdentities(f.Sockdir, f.ClientIdentities)
}
type multiStdinserverAcceptRes struct {
conn AuthenticatedConn
err error
}
type MultiStdinserverListener struct {
listeners []*stdinserverListener
accepts chan multiStdinserverAcceptRes
closed int32
}
// client identities must be validated
func multiStdinserverListenerFromClientIdentities(sockdir string, cis []string) (*MultiStdinserverListener, error) {
listeners := make([]*stdinserverListener, 0, len(cis))
var err error
for _, ci := range cis {
sockpath := path.Join(sockdir, ci)
l := &stdinserverListener{clientIdentity: ci}
if err = nethelpers.PreparePrivateSockpath(sockpath); err != nil {
break
}
if l.l, err = netssh.Listen(sockpath); err != nil {
break
}
listeners = append(listeners, l)
} }
l, err := netssh.Listen(f.sockpath)
if err != nil { if err != nil {
for _, l := range listeners {
l.Close() // FIXME error reporting?
}
return nil, err return nil, err
} }
return StdinserverListener{l}, nil return &MultiStdinserverListener{listeners: listeners}, nil
} }
type StdinserverListener struct { func (m *MultiStdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error){
l *netssh.Listener
if m.accepts == nil {
m.accepts = make(chan multiStdinserverAcceptRes, len(m.listeners))
for i := range m.listeners {
go func(i int) {
for atomic.LoadInt32(&m.closed) == 0 {
fmt.Fprintf(os.Stderr, "accepting\n")
conn, err := m.listeners[i].Accept(context.TODO())
fmt.Fprintf(os.Stderr, "incoming\n")
m.accepts <- multiStdinserverAcceptRes{conn, err}
}
}(i)
}
}
res := <- m.accepts
return res.conn, res.err
} }
func (l StdinserverListener) Addr() net.Addr { func (m *MultiStdinserverListener) Addr() (net.Addr) {
return netsshAddr{} return netsshAddr{}
} }
func (l StdinserverListener) Accept() (net.Conn, error) { func (m *MultiStdinserverListener) Close() error {
atomic.StoreInt32(&m.closed, 1)
var oneErr error
for _, l := range m.listeners {
if err := l.Close(); err != nil && oneErr == nil {
oneErr = err
}
}
return oneErr
}
// a single stdinserverListener (part of multiStinserverListener)
type stdinserverListener struct {
l *netssh.Listener
clientIdentity string
}
func (l stdinserverListener) Addr() net.Addr {
return netsshAddr{}
}
func (l stdinserverListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
c, err := l.l.Accept() c, err := l.l.Accept()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return netsshConnToNetConnAdatper{c}, nil return netsshConnToNetConnAdatper{c, l.clientIdentity}, nil
} }
func (l StdinserverListener) Close() (err error) { func (l stdinserverListener) Close() (err error) {
return l.l.Close() return l.l.Close()
} }
@ -66,12 +145,16 @@ func (netsshAddr) String() string { return "???" }
type netsshConnToNetConnAdatper struct { type netsshConnToNetConnAdatper struct {
io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn io.ReadWriteCloser // works for both netssh.SSHConn and netssh.ServeConn
clientIdentity string
} }
func (a netsshConnToNetConnAdatper) ClientIdentity() string { return a.clientIdentity }
func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} } func (netsshConnToNetConnAdatper) LocalAddr() net.Addr { return netsshAddr{} }
func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} } func (netsshConnToNetConnAdatper) RemoteAddr() net.Addr { return netsshAddr{} }
// FIXME log warning once!
func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil } func (netsshConnToNetConnAdatper) SetDeadline(t time.Time) error { return nil }
func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil } func (netsshConnToNetConnAdatper) SetReadDeadline(t time.Time) error { return nil }

View File

@ -3,19 +3,89 @@ package serve
import ( import (
"github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/config"
"net" "net"
"github.com/pkg/errors"
"context"
) )
type TCPListenerFactory struct { type TCPListenerFactory struct {
Address string address *net.TCPAddr
clientMap *ipMap
}
type ipMapEntry struct {
ip net.IP
ident string
}
type ipMap struct {
entries []ipMapEntry
}
func ipMapFromConfig(clients map[string]string) (*ipMap, error) {
entries := make([]ipMapEntry, 0, len(clients))
for clientIPString, clientIdent := range clients {
clientIP := net.ParseIP(clientIPString)
if clientIP == nil {
return nil, errors.Errorf("cannot parse client IP %q", clientIPString)
}
if err := ValidateClientIdentity(clientIdent); err != nil {
return nil, errors.Wrapf(err,"invalid client identity for IP %q", clientIPString)
}
entries = append(entries, ipMapEntry{clientIP, clientIdent})
}
return &ipMap{entries: entries}, nil
}
func (m *ipMap) Get(ip net.IP) (string, error) {
for _, e := range m.entries {
if e.ip.Equal(ip) {
return e.ident, nil
}
}
return "", errors.Errorf("no identity mapping for client IP %s", ip)
} }
func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) { func TCPListenerFactoryFromConfig(c *config.Global, in *config.TCPServe) (*TCPListenerFactory, error) {
addr, err := net.ResolveTCPAddr("tcp", in.Listen)
if err != nil {
return nil, errors.Wrap(err, "cannot parse listen address")
}
clientMap, err := ipMapFromConfig(in.Clients)
if err != nil {
return nil, errors.Wrap(err, "cannot parse client IP map")
}
lf := &TCPListenerFactory{ lf := &TCPListenerFactory{
Address: in.Listen, address: addr,
clientMap: clientMap,
} }
return lf, nil return lf, nil
} }
func (f *TCPListenerFactory) Listen() (net.Listener, error) { func (f *TCPListenerFactory) Listen() (AuthenticatedListener, error) {
return net.Listen("tcp", f.Address) l, err := net.ListenTCP("tcp", f.address)
if err != nil {
return nil, err
}
return &TCPAuthListener{l, f.clientMap}, nil
} }
type TCPAuthListener struct {
*net.TCPListener
clientMap *ipMap
}
func (f *TCPAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
nc, err := f.TCPListener.Accept()
if err != nil {
return nil, err
}
clientIP := nc.RemoteAddr().(*net.TCPAddr).IP
clientIdent, err := f.clientMap.Get(clientIP)
if err != nil {
getLogger(ctx).WithField("ip", clientIP).Error("client IP not in client map")
nc.Close()
return nil, err
}
return authConn{nc, clientIdent}, nil
}

View File

@ -8,13 +8,13 @@ import (
"github.com/zrepl/zrepl/tlsconf" "github.com/zrepl/zrepl/tlsconf"
"net" "net"
"time" "time"
"context"
) )
type TLSListenerFactory struct { type TLSListenerFactory struct {
address string address string
clientCA *x509.CertPool clientCA *x509.CertPool
serverCert tls.Certificate serverCert tls.Certificate
clientCommonName string
handshakeTimeout time.Duration handshakeTimeout time.Duration
} }
@ -23,12 +23,10 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
address: in.Listen, address: in.Listen,
} }
if in.Ca == "" || in.Cert == "" || in.Key == "" || in.ClientCN == "" { if in.Ca == "" || in.Cert == "" || in.Key == "" {
return nil, errors.New("fields 'ca', 'cert', 'key' and 'client_cn' must be specified") return nil, errors.New("fields 'ca', 'cert' and 'key'must be specified")
} }
lf.clientCommonName = in.ClientCN
lf.clientCA, err = tlsconf.ParseCAFile(in.Ca) lf.clientCA, err = tlsconf.ParseCAFile(in.Ca)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "cannot parse ca file") return nil, errors.Wrap(err, "cannot parse ca file")
@ -42,11 +40,25 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL
return lf, nil return lf, nil
} }
func (f *TLSListenerFactory) Listen() (net.Listener, error) { func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) {
l, err := net.Listen("tcp", f.address) l, err := net.Listen("tcp", f.address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.clientCommonName, f.handshakeTimeout) tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout)
return tl, nil return tlsAuthListener{tl}, nil
} }
type tlsAuthListener struct {
*tlsconf.ClientAuthListener
}
func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) {
c, cn, err := l.ClientAuthListener.Accept()
if err != nil {
return nil, err
}
return authConn{c, cn}, nil
}

13
main.go
View File

@ -57,6 +57,18 @@ var statusCmd = &cobra.Command{
}, },
} }
var stdinserverCmd = &cobra.Command{
Use: "stdinserver CLIENT_IDENTITY",
Short: "start in stdinserver mode (from authorized_keys file)",
RunE: func(cmd *cobra.Command, args []string) error {
conf, err := config.ParseConfig(rootArgs.configFile)
if err != nil {
return err
}
return client.RunStdinserver(conf, args)
},
}
var rootArgs struct { var rootArgs struct {
configFile string configFile string
} }
@ -67,6 +79,7 @@ func init() {
rootCmd.AddCommand(daemonCmd) rootCmd.AddCommand(daemonCmd)
rootCmd.AddCommand(wakeupCmd) rootCmd.AddCommand(wakeupCmd)
rootCmd.AddCommand(statusCmd) rootCmd.AddCommand(statusCmd)
rootCmd.AddCommand(stdinserverCmd)
} }
func main() { func main() {

View File

@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"io/ioutil" "io/ioutil"
"net" "net"
"time" "time"
@ -24,13 +23,12 @@ func ParseCAFile(certfile string) (*x509.CertPool, error) {
type ClientAuthListener struct { type ClientAuthListener struct {
l net.Listener l net.Listener
clientCommonName string
handshakeTimeout time.Duration handshakeTimeout time.Duration
} }
func NewClientAuthListener( func NewClientAuthListener(
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate, l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
clientCommonName string, handshakeTimeout time.Duration) *ClientAuthListener { handshakeTimeout time.Duration) *ClientAuthListener {
if ca == nil { if ca == nil {
panic(ca) panic(ca)
@ -38,9 +36,6 @@ func NewClientAuthListener(
if serverCert.Certificate == nil || serverCert.PrivateKey == nil { if serverCert.Certificate == nil || serverCert.PrivateKey == nil {
panic(serverCert) panic(serverCert)
} }
if clientCommonName == "" {
panic(clientCommonName)
}
tlsConf := tls.Config{ tlsConf := tls.Config{
Certificates: []tls.Certificate{serverCert}, Certificates: []tls.Certificate{serverCert},
@ -51,19 +46,18 @@ func NewClientAuthListener(
l = tls.NewListener(l, &tlsConf) l = tls.NewListener(l, &tlsConf)
return &ClientAuthListener{ return &ClientAuthListener{
l, l,
clientCommonName,
handshakeTimeout, handshakeTimeout,
} }
} }
func (l *ClientAuthListener) Accept() (c net.Conn, err error) { func (l *ClientAuthListener) Accept() (c net.Conn, clientCN string, err error) {
c, err = l.l.Accept() c, err = l.l.Accept()
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
tlsConn, ok := c.(*tls.Conn) tlsConn, ok := c.(*tls.Conn)
if !ok { if !ok {
return c, err return c, "", err
} }
var ( var (
@ -83,14 +77,10 @@ func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
goto CloseAndErr goto CloseAndErr
} }
cn = peerCerts[0].Subject.CommonName cn = peerCerts[0].Subject.CommonName
if cn != l.clientCommonName { return c, cn, nil
err = fmt.Errorf("client cert common name does not match client_identity: %q != %q", cn, l.clientCommonName)
goto CloseAndErr
}
return c, nil
CloseAndErr: CloseAndErr:
c.Close() c.Close()
return nil, err return nil, "", err
} }
func (l *ClientAuthListener) Addr() net.Addr { func (l *ClientAuthListener) Addr() net.Addr {