mirror of
https://github.com/rclone/rclone.git
synced 2024-12-31 11:29:15 +01:00
fs/http: reload client certificates on expiry
In corporate environments, client certificates have short life times for added security, and they get renewed automatically. This means that client certificate can expire in the middle of long running command such as `mount`. This commit attempts to reload the client certificates 30s before they expire. This will be active for all backends which use HTTP.
This commit is contained in:
parent
dcecb0ede4
commit
f26d2c6ba8
@ -69,6 +69,13 @@ func NewTransportCustom(ctx context.Context, customize func(*http.Transport)) ht
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err)
|
||||
}
|
||||
if cert.Leaf == nil {
|
||||
// Leaf is always the first certificate
|
||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse the certificate")
|
||||
}
|
||||
}
|
||||
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
@ -148,17 +155,24 @@ type Transport struct {
|
||||
userAgent string
|
||||
headers []*fs.HTTPOption
|
||||
metrics *Metrics
|
||||
// Filename of the client cert in case we need to reload it
|
||||
clientCert string
|
||||
clientKey string
|
||||
// Mutex for serializing attempts at reloading the certificates
|
||||
reloadMutex sync.Mutex
|
||||
}
|
||||
|
||||
// newTransport wraps the http.Transport passed in and logs all
|
||||
// roundtrips including the body if logBody is set.
|
||||
func newTransport(ci *fs.ConfigInfo, transport *http.Transport) *Transport {
|
||||
return &Transport{
|
||||
Transport: transport,
|
||||
dump: ci.Dump,
|
||||
userAgent: ci.UserAgent,
|
||||
headers: ci.Headers,
|
||||
metrics: DefaultMetrics,
|
||||
Transport: transport,
|
||||
dump: ci.Dump,
|
||||
userAgent: ci.UserAgent,
|
||||
headers: ci.Headers,
|
||||
metrics: DefaultMetrics,
|
||||
clientCert: ci.ClientCert,
|
||||
clientKey: ci.ClientKey,
|
||||
}
|
||||
}
|
||||
|
||||
@ -247,8 +261,44 @@ func cleanAuths(buf []byte) []byte {
|
||||
return buf
|
||||
}
|
||||
|
||||
var expireWindow = 30 * time.Second
|
||||
|
||||
func isCertificateExpired(cc *tls.Config) bool {
|
||||
return len(cc.Certificates) > 0 && cc.Certificates[0].Leaf != nil && time.Until(cc.Certificates[0].Leaf.NotAfter) < expireWindow
|
||||
}
|
||||
|
||||
func (t *Transport) reloadCertificates() {
|
||||
t.reloadMutex.Lock()
|
||||
defer t.reloadMutex.Unlock()
|
||||
// Check that the certificate is expired before trying to reload it
|
||||
// it might have been reloaded while we were waiting to lock the mutex
|
||||
if !isCertificateExpired(t.TLSClientConfig) {
|
||||
return
|
||||
}
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(t.clientCert, t.clientKey)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err)
|
||||
}
|
||||
// Check if we need to parse the certificate again, we need it
|
||||
// for checking the expiration date
|
||||
if cert.Leaf == nil {
|
||||
// Leaf is always the first certificate
|
||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse the certificate")
|
||||
}
|
||||
}
|
||||
t.TLSClientConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
|
||||
// Check if certificates are being used and the certificates are expired
|
||||
if isCertificateExpired(t.TLSClientConfig) {
|
||||
t.reloadCertificates()
|
||||
}
|
||||
|
||||
// Limit transactions per second if required
|
||||
accounting.LimitTPS(req.Context())
|
||||
// Force user agent
|
||||
|
@ -1,9 +1,24 @@
|
||||
package fshttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCleanAuth(t *testing.T) {
|
||||
@ -45,3 +60,118 @@ func TestCleanAuths(t *testing.T) {
|
||||
assert.Equal(t, test.want, got, test.in)
|
||||
}
|
||||
}
|
||||
|
||||
var certSerial = int64(0)
|
||||
|
||||
// Create a test certificate and key pair that is valid for a specific
|
||||
// duration
|
||||
func createTestCert(validity time.Duration) (keyPEM []byte, certPEM []byte, err error) {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keyBytes := x509.MarshalPKCS1PrivateKey(key)
|
||||
// PEM encoding of private key
|
||||
keyPEM = pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
},
|
||||
)
|
||||
|
||||
// Now create the certificate
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(validity).Add(expireWindow)
|
||||
|
||||
certSerial += 1
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(certSerial),
|
||||
Subject: pkix.Name{CommonName: "localhost"},
|
||||
SignatureAlgorithm: x509.SHA256WithRSA,
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
BasicConstraintsValid: true,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment | x509.KeyUsageDataEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
certPEM = pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: derBytes,
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func writeTestCert(t *testing.T, ci *fs.ConfigInfo, validity time.Duration) {
|
||||
keyPEM, certPEM, err := createTestCert(1 * time.Second)
|
||||
assert.NoError(t, err, "Cannot create test cert")
|
||||
err = os.WriteFile(ci.ClientCert, certPEM, 0666)
|
||||
assert.NoError(t, err, "Failed to write cert")
|
||||
err = os.WriteFile(ci.ClientKey, keyPEM, 0666)
|
||||
assert.NoError(t, err, "Failed to write key")
|
||||
}
|
||||
|
||||
func TestCertificates(t *testing.T) {
|
||||
startTime := time.Now()
|
||||
// Starting a TLS server
|
||||
expectedSerial := int64(0)
|
||||
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cert := r.TLS.PeerCertificates
|
||||
require.Greater(t, len(cert), 0, "No certificates received")
|
||||
expectedSerial += 1
|
||||
assert.Equal(t, expectedSerial, cert[0].SerialNumber.Int64(), "Did not get the correct serial number in certificate")
|
||||
// Check that the certificate hasn't expired. We cannot use cert validation
|
||||
// functions because those check for signature as well and our certificates
|
||||
// are not properly signed
|
||||
if time.Now().After(cert[0].NotAfter) {
|
||||
assert.Fail(t, "Certificate expired", "Certificate expires at %s, current time is %s", cert[0].NotAfter.Sub(startTime), time.Since(startTime))
|
||||
}
|
||||
|
||||
// Write some test data to fullfil the request
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
_, _ = fmt.Fprintln(w, "test data")
|
||||
}))
|
||||
defer ts.Close()
|
||||
// Modify servers config to request a client certificate
|
||||
// we cannot validate the certificate since we are not properly signing it
|
||||
ts.TLS.ClientAuth = tls.RequestClientCert
|
||||
|
||||
// Set --client-cert and --client-key in config to
|
||||
// a pair of temp files
|
||||
// create a test cert/key pair and write it to the files
|
||||
ctx := context.TODO()
|
||||
ci := fs.GetConfig(ctx)
|
||||
// Create a test certificate and write it to a temp file
|
||||
ci.ClientCert = t.TempDir() + "client.cert"
|
||||
ci.ClientKey = t.TempDir() + "client.key"
|
||||
validity := 1 * time.Second
|
||||
writeTestCert(t, ci, validity)
|
||||
|
||||
// Now create the client with the above settings
|
||||
// we need to disable TLS verification since we don't
|
||||
// care about server certificate
|
||||
client := NewClient(ctx)
|
||||
tt := client.Transport.(*Transport)
|
||||
tt.TLSClientConfig.InsecureSkipVerify = true
|
||||
|
||||
// Now make requests, the first request should be within
|
||||
// the valid window
|
||||
_, err := client.Get(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Wait for the 2* valid duration of the certificate so that has definitely expired
|
||||
time.Sleep(2 * validity)
|
||||
|
||||
// Create a new cert and write it to files
|
||||
writeTestCert(t, ci, validity)
|
||||
|
||||
// The new cert should be auto-loaded before we make this request
|
||||
_, err = client.Get(ts.URL)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user