diff --git a/httpie/client.py b/httpie/client.py index b79cb1be..bb5410de 100644 --- a/httpie/client.py +++ b/httpie/client.py @@ -1,19 +1,20 @@ +import argparse +import http.client import json import sys - -import http.client -import requests +import zlib from contextlib import contextmanager -from requests.adapters import HTTPAdapter -from requests.structures import CaseInsensitiveDict +from pathlib import Path -from httpie import sessions -from httpie import __version__ +import requests +from requests.adapters import HTTPAdapter + +from httpie import __version__, sessions from httpie.cli.constants import SSL_VERSION_ARG_MAPPING +from httpie.cli.dicts import RequestHeadersDict from httpie.plugins import plugin_manager from httpie.utils import repr_dict_nice -import zlib try: # https://urllib3.readthedocs.io/en/latest/security.html @@ -31,7 +32,7 @@ except (ImportError, AttributeError): FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8' JSON_CONTENT_TYPE = 'application/json' JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*' -DEFAULT_UA = 'HTTPie/%s' % __version__ +DEFAULT_UA = f'HTTPie/{__version__}' # noinspection PyProtectedMember @@ -48,46 +49,40 @@ def max_headers(limit): class HTTPieHTTPAdapter(HTTPAdapter): - def __init__(self, ssl_version=None, **kwargs): + def __init__(self, ssl_version=None, compress=0, **kwargs): self._ssl_version = ssl_version + self._compress = compress super().__init__(**kwargs) def init_poolmanager(self, *args, **kwargs): kwargs['ssl_version'] = self._ssl_version super().init_poolmanager(*args, **kwargs) - -class ContentCompressionHttpAdapter(HTTPAdapter): - - def __init__(self, compress, **kwargs): - self.compress = compress - super().__init__(**kwargs) - - def send(self, request, **kwargs): - if request.body and self.compress > 0: - deflater = zlib.compressobj() - if isinstance(request.body, bytes): - deflated_data = deflater.compress(request.body) - else: - deflated_data = deflater.compress(request.body.encode()) - deflated_data += deflater.flush() - if len(deflated_data) < len(request.body) or self.compress > 1: - request.body = deflated_data - request.headers['Content-Encoding'] = 'deflate' - request.headers['Content-Length'] = str(len(deflated_data)) + def send(self, request: requests.PreparedRequest, **kwargs): + if self._compress and request.body: + self._compress_body(request, self._compress) return super().send(request, **kwargs) + @staticmethod + def _compress_body(request: requests.PreparedRequest, compress: int): + deflater = zlib.compressobj() + if isinstance(request.body, bytes): + deflated_data = deflater.compress(request.body) + else: + deflated_data = deflater.compress(request.body.encode()) + deflated_data += deflater.flush() + if len(deflated_data) < len(request.body) or compress > 1: + request.body = deflated_data + request.headers['Content-Encoding'] = 'deflate' + request.headers['Content-Length'] = str(len(deflated_data)) -def get_requests_session(ssl_version, compress): + +def get_requests_session(ssl_version: str, compress: int) -> requests.Session: requests_session = requests.Session() - requests_session.mount( - 'https://', - HTTPieHTTPAdapter(ssl_version=ssl_version) - ) - if compress: - adapter = ContentCompressionHttpAdapter(compress) - for prefix in ['http://', 'https://']: - requests_session.mount(prefix, adapter) + adapter = HTTPieHTTPAdapter(ssl_version=ssl_version, compress=compress) + for prefix in ['http://', 'https://']: + requests_session.mount(prefix, adapter) + for cls in plugin_manager.get_transport_plugins(): transport_plugin = cls() requests_session.mount(prefix=transport_plugin.prefix, @@ -95,7 +90,10 @@ def get_requests_session(ssl_version, compress): return requests_session -def get_response(args, config_dir): +def get_response( + args: argparse.Namespace, + config_dir: Path +) -> requests.Response: """Send the request and return a `request.Response`.""" ssl_version = None @@ -123,32 +121,29 @@ def get_response(args, config_dir): return response -def dump_request(kwargs): +def dump_request(kwargs: dict): sys.stderr.write('\n>>> requests.request(**%s)\n\n' % repr_dict_nice(kwargs)) -def finalize_headers(headers): - final_headers = {} +def finalize_headers(headers: RequestHeadersDict) -> RequestHeadersDict: + final_headers = RequestHeadersDict() for name, value in headers.items(): if value is not None: - # >leading or trailing LWS MAY be removed without # >changing the semantics of the field value" # -https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html # Also, requests raises `InvalidHeader` for leading spaces. value = value.strip() - if isinstance(value, str): # See: https://github.com/jakubroztocil/httpie/issues/212 value = value.encode('utf8') - final_headers[name] = value return final_headers -def get_default_headers(args): - default_headers = CaseInsensitiveDict({ +def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict: + default_headers = RequestHeadersDict({ 'User-Agent': DEFAULT_UA }) @@ -165,7 +160,7 @@ def get_default_headers(args): return default_headers -def get_requests_kwargs(args, base_headers=None): +def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict: """ Translate our `args` into `requests.request` keyword arguments.