mirror of
https://github.com/rclone/rclone.git
synced 2024-11-25 18:04:55 +01:00
376 lines
7.2 KiB
Go
376 lines
7.2 KiB
Go
package httpclient
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"encoding/xml"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// XmlHeaderBytes holds the standard XML declaration (xml.Header) as a byte
// slice. marshalRequest places it in front of every XML-encoded request body.
var XmlHeaderBytes = []byte(xml.Header)
|
|
|
|
// Callback types accepted by HTTPClient.
type (
	// ErrorHandlerFunc is called by Request when executing the HTTP request
	// itself failed. It receives the response (which may be nil) together
	// with the error, and the error it returns is what Request reports.
	ErrorHandlerFunc func(*http.Response, error) error

	// PostHookFunc is called by Request after a response was received whose
	// status code matches the code the hook was registered for. Returning a
	// non-nil error stops further processing of the response.
	PostHookFunc func(*http.Request, *http.Response) error
)

// HTTPClient wraps an *http.Client with a base URL, default headers,
// per-status post hooks, an optional error handler and optional rate limiting.
type HTTPClient struct {
	// BaseURL is the prefix used to build the request URL when a request does
	// not carry a full URL of its own.
	BaseURL *url.URL
	// Headers are applied to every request; per-request headers override them.
	Headers http.Header
	// Client is the underlying client used to execute requests.
	Client *http.Client
	// PostHooks maps an HTTP status code to the hook run for such responses.
	PostHooks map[int]PostHookFunc

	// errorHandler, when set, post-processes errors from executing a request.
	errorHandler ErrorHandlerFunc
	// rateLimited reports whether SetRateLimit has been called.
	rateLimited bool
	// rateLimitChan holds one token per request slot.
	rateLimitChan chan struct{}
	// rateLimitTimeout bounds the wait for a token; zero or less means the
	// wait is unbounded.
	rateLimitTimeout time.Duration
}
|
|
|
|
func New() (httpClient *HTTPClient) {
|
|
return &HTTPClient{
|
|
Client: HttpClient,
|
|
Headers: make(http.Header),
|
|
PostHooks: make(map[int]PostHookFunc),
|
|
}
|
|
}
|
|
|
|
func Insecure() (httpClient *HTTPClient) {
|
|
httpClient = New()
|
|
httpClient.Client = InsecureHttpClient
|
|
return httpClient
|
|
}
|
|
|
|
// DefaultClient is a ready-to-use HTTPClient created with New when the
// package is initialized.
var DefaultClient = New()
|
|
|
|
func (c *HTTPClient) SetPostHook(onStatus int, hook PostHookFunc) {
|
|
c.PostHooks[onStatus] = hook
|
|
}
|
|
|
|
func (c *HTTPClient) SetErrorHandler(handler ErrorHandlerFunc) {
|
|
c.errorHandler = handler
|
|
}
|
|
|
|
func (c *HTTPClient) SetRateLimit(limit int, timeout time.Duration) {
|
|
c.rateLimited = true
|
|
c.rateLimitChan = make(chan struct{}, limit)
|
|
|
|
for i := 0; i < limit; i++ {
|
|
c.rateLimitChan <- struct{}{}
|
|
}
|
|
|
|
c.rateLimitTimeout = timeout
|
|
}
|
|
|
|
func (c *HTTPClient) buildURL(req *RequestData) *url.URL {
|
|
bu := c.BaseURL
|
|
|
|
rpath := req.Path
|
|
|
|
if strings.HasSuffix(bu.Path, "/") && strings.HasPrefix(rpath, "/") {
|
|
rpath = rpath[1:]
|
|
}
|
|
|
|
opaque := EscapePath(bu.Path + rpath)
|
|
|
|
u := &url.URL{
|
|
Scheme: bu.Scheme,
|
|
Host: bu.Host,
|
|
Opaque: opaque,
|
|
}
|
|
|
|
if req.Params != nil {
|
|
u.RawQuery = req.Params.Encode()
|
|
}
|
|
|
|
return u
|
|
}
|
|
|
|
func (c *HTTPClient) setHeaders(req *RequestData, httpReq *http.Request) {
|
|
switch req.RespEncoding {
|
|
case EncodingJSON:
|
|
httpReq.Header.Set("Accept", "application/json")
|
|
case EncodingXML:
|
|
httpReq.Header.Set("Accept", "application/xml")
|
|
}
|
|
|
|
if c.Headers != nil {
|
|
for key, values := range c.Headers {
|
|
for _, value := range values {
|
|
httpReq.Header.Set(key, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
if req.Headers != nil {
|
|
for key, values := range req.Headers {
|
|
for _, value := range values {
|
|
httpReq.Header.Set(key, value)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *HTTPClient) checkStatus(req *RequestData, response *http.Response) (err error) {
|
|
if req.ExpectedStatus != nil {
|
|
statusOk := false
|
|
|
|
for _, status := range req.ExpectedStatus {
|
|
if response.StatusCode == status {
|
|
statusOk = true
|
|
}
|
|
}
|
|
|
|
if !statusOk {
|
|
lr := io.LimitReader(response.Body, 10*1024)
|
|
contentBytes, _ := ioutil.ReadAll(lr)
|
|
content := string(contentBytes)
|
|
|
|
err = InvalidStatusError{
|
|
Expected: req.ExpectedStatus,
|
|
Got: response.StatusCode,
|
|
Headers: response.Header,
|
|
Content: content,
|
|
}
|
|
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *HTTPClient) unmarshalResponse(req *RequestData, response *http.Response) (err error) {
|
|
var buf []byte
|
|
|
|
switch req.RespEncoding {
|
|
case EncodingJSON:
|
|
defer response.Body.Close()
|
|
|
|
if buf, err = ioutil.ReadAll(response.Body); err != nil {
|
|
return err
|
|
}
|
|
|
|
err = json.Unmarshal(buf, req.RespValue)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
|
|
case EncodingXML:
|
|
defer response.Body.Close()
|
|
|
|
if buf, err = ioutil.ReadAll(response.Body); err != nil {
|
|
return err
|
|
}
|
|
|
|
err = xml.Unmarshal(buf, req.RespValue)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
switch req.RespValue.(type) {
|
|
case *[]byte:
|
|
defer response.Body.Close()
|
|
|
|
if buf, err = ioutil.ReadAll(response.Body); err != nil {
|
|
return err
|
|
}
|
|
|
|
respVal := req.RespValue.(*[]byte)
|
|
*respVal = buf
|
|
|
|
return nil
|
|
}
|
|
|
|
if req.RespConsume {
|
|
defer response.Body.Close()
|
|
ioutil.ReadAll(response.Body)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *HTTPClient) marshalRequest(req *RequestData) (err error) {
|
|
if req.ReqReader != nil || req.ReqValue == nil {
|
|
return nil
|
|
}
|
|
|
|
if req.Headers == nil {
|
|
req.Headers = make(http.Header)
|
|
}
|
|
|
|
var buf []byte
|
|
|
|
switch req.ReqEncoding {
|
|
case EncodingJSON:
|
|
buf, err = json.Marshal(req.ReqValue)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.ReqReader = bytes.NewReader(buf)
|
|
req.Headers.Set("Content-Type", "application/json")
|
|
req.Headers.Set("Content-Length", fmt.Sprintf("%d", len(buf)))
|
|
|
|
req.ReqContentLength = int64(len(buf))
|
|
|
|
return nil
|
|
|
|
case EncodingXML:
|
|
buf, err = xml.Marshal(req.ReqValue)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
buf = append(XmlHeaderBytes, buf...)
|
|
|
|
req.ReqReader = bytes.NewReader(buf)
|
|
req.Headers.Set("Content-Type", "application/xml")
|
|
req.Headers.Set("Content-Length", fmt.Sprintf("%d", len(buf)))
|
|
|
|
req.ReqContentLength = int64(len(buf))
|
|
|
|
return nil
|
|
|
|
case EncodingForm:
|
|
if data, ok := req.ReqValue.(url.Values); ok {
|
|
formStr := data.Encode()
|
|
req.ReqReader = strings.NewReader(formStr)
|
|
req.Headers.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
req.Headers.Set("Content-Length", fmt.Sprintf("%d", len(formStr)))
|
|
|
|
req.ReqContentLength = int64(len(formStr))
|
|
|
|
return nil
|
|
} else {
|
|
return fmt.Errorf("HTTPClient: invalid ReqValue type %T", req.ReqValue)
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("HTTPClient: invalid ReqEncoding: %s", req.ReqEncoding)
|
|
}
|
|
|
|
func (c *HTTPClient) runPostHook(req *http.Request, response *http.Response) (err error) {
|
|
hook, ok := c.PostHooks[response.StatusCode]
|
|
|
|
if ok {
|
|
err = hook(req, response)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c *HTTPClient) Request(req *RequestData) (response *http.Response, err error) {
|
|
err = c.marshalRequest(req)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
r, err := http.NewRequest(req.Method, req.FullURL, req.ReqReader)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if req.Context != nil {
|
|
r = r.WithContext(req.Context)
|
|
}
|
|
|
|
r.ContentLength = req.ReqContentLength
|
|
|
|
if req.FullURL == "" {
|
|
r.URL = c.buildURL(req)
|
|
r.Host = r.URL.Host
|
|
}
|
|
|
|
c.setHeaders(req, r)
|
|
|
|
if c.rateLimited {
|
|
if c.rateLimitTimeout > 0 {
|
|
select {
|
|
case t := <-c.rateLimitChan:
|
|
defer func() {
|
|
c.rateLimitChan <- t
|
|
}()
|
|
case <-time.After(c.rateLimitTimeout):
|
|
return nil, RateLimitTimeoutError
|
|
}
|
|
} else {
|
|
t := <-c.rateLimitChan
|
|
defer func() {
|
|
c.rateLimitChan <- t
|
|
}()
|
|
}
|
|
}
|
|
|
|
isTraceEnabled := os.Getenv("HTTPCLIENT_TRACE") != ""
|
|
|
|
if isTraceEnabled {
|
|
requestBytes, _ := httputil.DumpRequestOut(r, true)
|
|
fmt.Println(string(requestBytes))
|
|
}
|
|
|
|
if req.IgnoreRedirects {
|
|
transport := c.Client.Transport
|
|
|
|
if transport == nil {
|
|
transport = http.DefaultTransport
|
|
}
|
|
|
|
response, err = transport.RoundTrip(r)
|
|
} else {
|
|
response, err = c.Client.Do(r)
|
|
}
|
|
|
|
if err != nil {
|
|
if req.Context != nil {
|
|
// If we got an error, and the context has been canceled,
|
|
// the context's error is probably more useful.
|
|
select {
|
|
case <-req.Context.Done():
|
|
err = req.Context.Err()
|
|
default:
|
|
}
|
|
}
|
|
if c.errorHandler != nil {
|
|
err = c.errorHandler(response, err)
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
if isTraceEnabled {
|
|
responseBytes, _ := httputil.DumpResponse(response, true)
|
|
fmt.Println(string(responseBytes))
|
|
}
|
|
|
|
if err = c.runPostHook(r, response); err != nil {
|
|
return response, err
|
|
}
|
|
|
|
if err = c.checkStatus(req, response); err != nil {
|
|
defer response.Body.Close()
|
|
return response, err
|
|
}
|
|
|
|
if err = c.unmarshalResponse(req, response); err != nil {
|
|
return response, err
|
|
}
|
|
|
|
return response, nil
|
|
}
|