diff --git a/daemon/serve/serve_tls.go b/daemon/serve/serve_tls.go index 9172b32..bc95e41 100644 --- a/daemon/serve/serve_tls.go +++ b/daemon/serve/serve_tls.go @@ -3,6 +3,7 @@ package serve import ( "crypto/tls" "crypto/x509" + "fmt" "github.com/pkg/errors" "github.com/zrepl/zrepl/config" "github.com/zrepl/zrepl/tlsconf" @@ -16,6 +17,7 @@ type TLSListenerFactory struct { clientCA *x509.CertPool serverCert tls.Certificate handshakeTimeout time.Duration + clientCNs map[string]struct{} } func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TLSListenerFactory, err error) { @@ -38,6 +40,15 @@ func TLSListenerFactoryFromConfig(c *config.Global, in *config.TLSServe) (lf *TL return nil, errors.Wrap(err, "cannot parse cer/key pair") } + lf.clientCNs = make(map[string]struct{}, len(in.ClientCNs)) + for i, cn := range in.ClientCNs { + if err := ValidateClientIdentity(cn); err != nil { + return nil, errors.Wrapf(err, "unsuitable client_cn #%d %q", i, cn) + } + // dupes are ok fr now + lf.clientCNs[cn] = struct{}{} + } + return lf, nil } @@ -47,11 +58,12 @@ func (f *TLSListenerFactory) Listen() (AuthenticatedListener, error) { return nil, err } tl := tlsconf.NewClientAuthListener(l, f.clientCA, f.serverCert, f.handshakeTimeout) - return tlsAuthListener{tl}, nil + return tlsAuthListener{tl, f.clientCNs}, nil } type tlsAuthListener struct { *tlsconf.ClientAuthListener + clientCNs map[string]struct{} } func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) { @@ -59,6 +71,12 @@ func (l tlsAuthListener) Accept(ctx context.Context) (AuthenticatedConn, error) if err != nil { return nil, err } + if _, ok := l.clientCNs[cn]; !ok { + if err := c.Close(); err != nil { + getLogger(ctx).WithError(err).Error("error closing connection with unauthorized common name") + } + return nil, fmt.Errorf("unauthorized client common name %q from %s", cn, c.RemoteAddr()) + } return authConn{c, cn}, nil }