Fix --ssl with --compress; refactor client

This commit is contained in:
Jakub Roztocil 2019-08-31 17:52:56 +02:00
parent aba3b1ec01
commit 224519e0e2

View File

@ -1,19 +1,20 @@
import argparse
import http.client
import json
import sys
import http.client
import requests
import zlib
from contextlib import contextmanager
from requests.adapters import HTTPAdapter
from requests.structures import CaseInsensitiveDict
from pathlib import Path
from httpie import sessions
from httpie import __version__
import requests
from requests.adapters import HTTPAdapter
from httpie import __version__, sessions
from httpie.cli.constants import SSL_VERSION_ARG_MAPPING
from httpie.cli.dicts import RequestHeadersDict
from httpie.plugins import plugin_manager
from httpie.utils import repr_dict_nice
import zlib
try:
# https://urllib3.readthedocs.io/en/latest/security.html
@ -31,7 +32,7 @@ except (ImportError, AttributeError):
FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8'
JSON_CONTENT_TYPE = 'application/json'
JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*'
DEFAULT_UA = 'HTTPie/%s' % __version__
DEFAULT_UA = f'HTTPie/{__version__}'
# noinspection PyProtectedMember
@ -48,46 +49,40 @@ def max_headers(limit):
class HTTPieHTTPAdapter(HTTPAdapter):
def __init__(self, ssl_version=None, **kwargs):
def __init__(self, ssl_version=None, compress=0, **kwargs):
self._ssl_version = ssl_version
self._compress = compress
super().__init__(**kwargs)
def init_poolmanager(self, *args, **kwargs):
kwargs['ssl_version'] = self._ssl_version
super().init_poolmanager(*args, **kwargs)
class ContentCompressionHttpAdapter(HTTPAdapter):
def __init__(self, compress, **kwargs):
self.compress = compress
super().__init__(**kwargs)
def send(self, request, **kwargs):
if request.body and self.compress > 0:
deflater = zlib.compressobj()
if isinstance(request.body, bytes):
deflated_data = deflater.compress(request.body)
else:
deflated_data = deflater.compress(request.body.encode())
deflated_data += deflater.flush()
if len(deflated_data) < len(request.body) or self.compress > 1:
request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data))
def send(self, request: requests.PreparedRequest, **kwargs):
if self._compress and request.body:
self._compress_body(request, self._compress)
return super().send(request, **kwargs)
@staticmethod
def _compress_body(request: requests.PreparedRequest, compress: int):
deflater = zlib.compressobj()
if isinstance(request.body, bytes):
deflated_data = deflater.compress(request.body)
else:
deflated_data = deflater.compress(request.body.encode())
deflated_data += deflater.flush()
if len(deflated_data) < len(request.body) or compress > 1:
request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data))
def get_requests_session(ssl_version, compress):
def get_requests_session(ssl_version: str, compress: int) -> requests.Session:
requests_session = requests.Session()
requests_session.mount(
'https://',
HTTPieHTTPAdapter(ssl_version=ssl_version)
)
if compress:
adapter = ContentCompressionHttpAdapter(compress)
for prefix in ['http://', 'https://']:
requests_session.mount(prefix, adapter)
adapter = HTTPieHTTPAdapter(ssl_version=ssl_version, compress=compress)
for prefix in ['http://', 'https://']:
requests_session.mount(prefix, adapter)
for cls in plugin_manager.get_transport_plugins():
transport_plugin = cls()
requests_session.mount(prefix=transport_plugin.prefix,
@ -95,7 +90,10 @@ def get_requests_session(ssl_version, compress):
return requests_session
def get_response(args, config_dir):
def get_response(
args: argparse.Namespace,
config_dir: Path
) -> requests.Response:
"""Send the request and return a `request.Response`."""
ssl_version = None
@ -123,32 +121,29 @@ def get_response(args, config_dir):
return response
def dump_request(kwargs):
def dump_request(kwargs: dict):
sys.stderr.write('\n>>> requests.request(**%s)\n\n'
% repr_dict_nice(kwargs))
def finalize_headers(headers):
final_headers = {}
def finalize_headers(headers: RequestHeadersDict) -> RequestHeadersDict:
final_headers = RequestHeadersDict()
for name, value in headers.items():
if value is not None:
# >leading or trailing LWS MAY be removed without
# >changing the semantics of the field value"
# -https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
# Also, requests raises `InvalidHeader` for leading spaces.
value = value.strip()
if isinstance(value, str):
# See: https://github.com/jakubroztocil/httpie/issues/212
value = value.encode('utf8')
final_headers[name] = value
return final_headers
def get_default_headers(args):
default_headers = CaseInsensitiveDict({
def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict:
default_headers = RequestHeadersDict({
'User-Agent': DEFAULT_UA
})
@ -165,7 +160,7 @@ def get_default_headers(args):
return default_headers
def get_requests_kwargs(args, base_headers=None):
def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict:
"""
Translate our `args` into `requests.request` keyword arguments.