zrepl/tlsconf/tlsconf.go

122 lines
2.6 KiB
Go
Raw Normal View History

2018-08-25 12:58:17 +02:00
package tlsconf
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"time"
)
// ParseCAFile reads the PEM-encoded certificates stored in certfile and
// returns them as a freshly allocated certificate pool.
//
// The read error is returned unchanged if the file cannot be read. If no
// certificate can be parsed from the file contents, an error is returned
// instead of an empty pool.
func ParseCAFile(certfile string) (*x509.CertPool, error) {
	contents, readErr := ioutil.ReadFile(certfile)
	if readErr != nil {
		return nil, readErr
	}

	caPool := x509.NewCertPool()
	if ok := caPool.AppendCertsFromPEM(contents); !ok {
		return nil, errors.New("PEM parsing error")
	}
	return caPool, nil
}
// ClientAuthListener wraps a listener and authenticates each accepted
// connection: the peer must complete a TLS handshake with a client
// certificate whose subject common name equals an expected value.
type ClientAuthListener struct {
	// l is the wrapped listener that connections are accepted from.
	l net.Listener
	// clientCommonName is the subject common name the client certificate
	// must carry.
	clientCommonName string
	// handshakeTimeout bounds how long the TLS handshake may take.
	handshakeTimeout time.Duration
}
func NewClientAuthListener(
l net.Listener, ca *x509.CertPool, serverCert tls.Certificate,
clientCommonName string, handshakeTimeout time.Duration) *ClientAuthListener {
if ca == nil {
panic(ca)
}
if serverCert.Certificate == nil || serverCert.PrivateKey == nil {
panic(serverCert)
}
if clientCommonName == "" {
panic(clientCommonName)
}
tlsConf := tls.Config{
2018-08-25 21:30:25 +02:00
Certificates: []tls.Certificate{serverCert},
2018-08-25 12:58:17 +02:00
ClientCAs: ca,
ClientAuth: tls.RequireAndVerifyClientCert,
PreferServerCipherSuites: true,
}
l = tls.NewListener(l, &tlsConf)
return &ClientAuthListener{
l,
clientCommonName,
handshakeTimeout,
}
}
// Accept waits for the next connection on the wrapped listener and
// authenticates the client before returning it.
//
// For a *tls.Conn:
//   - the TLS handshake must complete within handshakeTimeout;
//   - the client must present exactly one certificate;
//   - that certificate's subject common name must equal clientCommonName.
//
// On any failure the connection is closed, and a nil net.Conn is returned
// together with the error. Connections that are not *tls.Conn are returned
// without these checks.
//
// NOTE(review): a single misbehaving client makes Accept return a non-nil
// error; confirm that callers keep accepting rather than treating every
// error as fatal.
func (l *ClientAuthListener) Accept() (c net.Conn, err error) {
	c, err = l.l.Accept()
	if err != nil {
		return nil, err
	}
	tlsConn, ok := c.(*tls.Conn)
	if !ok {
		// NewClientAuthListener always installs a tls.NewListener, so this
		// only happens if l.l was set some other way; pass the conn through.
		return c, nil
	}
	if err = l.authenticate(tlsConn); err != nil {
		// Best-effort close: the authentication error is what the caller
		// needs to see, not a secondary close failure.
		_ = c.Close()
		return nil, err
	}
	return c, nil
}

// authenticate runs the TLS handshake on tlsConn under the handshake
// deadline, verifies the client certificate, and then clears the deadline so
// it does not apply to the remaining lifetime of the connection.
func (l *ClientAuthListener) authenticate(tlsConn *tls.Conn) error {
	if err := tlsConn.SetDeadline(time.Now().Add(l.handshakeTimeout)); err != nil {
		return err
	}
	if err := tlsConn.Handshake(); err != nil {
		return err
	}

	peerCerts := tlsConn.ConnectionState().PeerCertificates
	if len(peerCerts) != 1 {
		return errors.New("unexpected number of certificates presented by TLS client")
	}
	if cn := peerCerts[0].Subject.CommonName; cn != l.clientCommonName {
		return fmt.Errorf("client cert common name does not match client_identity: %q != %q", cn, l.clientCommonName)
	}

	// The deadline above is only meant to bound the handshake. Left in place,
	// it would make every Read/Write fail once handshakeTimeout has elapsed;
	// the zero time.Time removes it.
	if err := tlsConn.SetDeadline(time.Time{}); err != nil {
		return err
	}
	return nil
}
// Addr returns the network address of the wrapped listener.
func (l *ClientAuthListener) Addr() net.Addr {
	inner := l.l
	return inner.Addr()
}
// Close closes the wrapped listener and returns its error, if any.
func (l *ClientAuthListener) Close() error {
	inner := l.l
	return inner.Close()
}
// ClientAuthClient returns a TLS client configuration that:
//   - presents clientCert to the server;
//   - verifies the server's certificate chain against rootCA;
//   - verifies the server certificate against serverName.
//
// The error result is always nil in the current implementation.
//
// It panics if serverName is empty, if rootCA is nil, or if clientCert lacks
// its certificate chain or private key.
func ClientAuthClient(serverName string, rootCA *x509.CertPool, clientCert tls.Certificate) (*tls.Config, error) {
	if serverName == "" {
		panic("tlsconf: ClientAuthClient: serverName must not be empty")
	}
	if rootCA == nil {
		panic("tlsconf: ClientAuthClient: rootCA must not be nil")
	}
	if clientCert.Certificate == nil || clientCert.PrivateKey == nil {
		// Deliberately not panicking with clientCert itself: the value can
		// hold private key material, which a recover-and-log would print.
		panic("tlsconf: ClientAuthClient: clientCert must have both a certificate chain and a private key")
	}

	// NameToCertificate (filled by the deprecated BuildNameToCertificate) is
	// only consulted when serving; a client selects its certificate from
	// Certificates, so it is intentionally not built here.
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{clientCert},
		RootCAs:      rootCA,
		ServerName:   serverName,
	}
	return tlsConfig, nil
}