From 0b6cdb7677c99b8f360dd51d9a38331b23c2c909 Mon Sep 17 00:00:00 2001 From: Nick Craig-Wood Date: Wed, 16 Oct 2019 11:21:26 +0100 Subject: [PATCH] fshttp: allow Transport to be customized #3631 --- fs/fshttp/http.go | 119 +++++++++++++++++++++++++--------------------- 1 file changed, 66 insertions(+), 53 deletions(-) diff --git a/fs/fshttp/http.go b/fs/fshttp/http.go index a3f381401..304be7c1a 100644 --- a/fs/fshttp/http.go +++ b/fs/fshttp/http.go @@ -127,72 +127,85 @@ func ResetTransport() { noTransport = new(sync.Once) } +// NewTransportCustom returns an http.RoundTripper with the correct timeouts. +// The customize function is called if set to give the caller an opportunity to +// customize any defaults in the Transport. +func NewTransportCustom(ci *fs.ConfigInfo, customize func(*http.Transport)) http.RoundTripper { + // Start with a sensible set of defaults then override. + // This also means we get new stuff when it gets added to go + t := new(http.Transport) + setDefaults(t, http.DefaultTransport.(*http.Transport)) + t.Proxy = http.ProxyFromEnvironment + t.MaxIdleConnsPerHost = 2 * (ci.Checkers + ci.Transfers + 1) + t.MaxIdleConns = 2 * t.MaxIdleConnsPerHost + t.TLSHandshakeTimeout = ci.ConnectTimeout + t.ResponseHeaderTimeout = ci.Timeout + + // TLS Config + t.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: ci.InsecureSkipVerify, + } + + // Load client certs + if ci.ClientCert != "" || ci.ClientKey != "" { + if ci.ClientCert == "" || ci.ClientKey == "" { + log.Fatalf("Both --client-cert and --client-key must be set") + } + cert, err := tls.LoadX509KeyPair(ci.ClientCert, ci.ClientKey) + if err != nil { + log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err) + } + t.TLSClientConfig.Certificates = []tls.Certificate{cert} + t.TLSClientConfig.BuildNameToCertificate() + } + + // Load CA cert + if ci.CaCert != "" { + caCert, err := ioutil.ReadFile(ci.CaCert) + if err != nil { + log.Fatalf("Failed to read --ca-cert: %v", err) + } + caCertPool := x509.NewCertPool() + ok := caCertPool.AppendCertsFromPEM(caCert) + if !ok { + log.Fatalf("Failed to add certificates from --ca-cert") + } + t.TLSClientConfig.RootCAs = caCertPool + } + + t.DisableCompression = ci.NoGzip + t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialContextTimeout(ctx, network, addr, ci) + } + t.IdleConnTimeout = 60 * time.Second + t.ExpectContinueTimeout = ci.ConnectTimeout + + // customize the transport if required + if customize != nil { + customize(t) + } + + // Wrap that http.Transport in our own transport + return newTransport(ci, t) +} + // NewTransport returns an http.RoundTripper with the correct timeouts func NewTransport(ci *fs.ConfigInfo) http.RoundTripper { (*noTransport).Do(func() { - // Start with a sensible set of defaults then override. - // This also means we get new stuff when it gets added to go - t := new(http.Transport) - setDefaults(t, http.DefaultTransport.(*http.Transport)) - t.Proxy = http.ProxyFromEnvironment - t.MaxIdleConnsPerHost = 2 * (ci.Checkers + ci.Transfers + 1) - t.MaxIdleConns = 2 * t.MaxIdleConnsPerHost - t.TLSHandshakeTimeout = ci.ConnectTimeout - t.ResponseHeaderTimeout = ci.Timeout - - // TLS Config - t.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: ci.InsecureSkipVerify, - } - - // Load client certs - if ci.ClientCert != "" || ci.ClientKey != "" { - if ci.ClientCert == "" || ci.ClientKey == "" { - log.Fatalf("Both --client-cert and --client-key must be set") - } - cert, err := tls.LoadX509KeyPair(ci.ClientCert, ci.ClientKey) - if err != nil { - log.Fatalf("Failed to load --client-cert/--client-key pair: %v", err) - } - t.TLSClientConfig.Certificates = []tls.Certificate{cert} - t.TLSClientConfig.BuildNameToCertificate() - } - - // Load CA cert - if ci.CaCert != "" { - caCert, err := ioutil.ReadFile(ci.CaCert) - if err != nil { - log.Fatalf("Failed to read --ca-cert: %v", err) - } - caCertPool := x509.NewCertPool() - ok := caCertPool.AppendCertsFromPEM(caCert) - if !ok { - log.Fatalf("Failed to add certificates from --ca-cert") - } - t.TLSClientConfig.RootCAs = caCertPool - } - - t.DisableCompression = ci.NoGzip - t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialContextTimeout(ctx, network, addr, ci) - } - t.IdleConnTimeout = 60 * time.Second - t.ExpectContinueTimeout = ci.ConnectTimeout - // Wrap that http.Transport in our own transport - transport = newTransport(ci, t) + transport = NewTransportCustom(ci, nil) }) return transport } // NewClient returns an http.Client with the correct timeouts func NewClient(ci *fs.ConfigInfo) *http.Client { - transport := &http.Client{ + client := &http.Client{ Transport: NewTransport(ci), } if ci.Cookie { - transport.Jar = cookieJar + client.Jar = cookieJar } - return transport + return client } // Transport is a our http Transport which wraps an http.Transport