From 4dffac7a25e9c1098df553138114319a9fbf0531 Mon Sep 17 00:00:00 2001 From: Jakub Roztocil Date: Sun, 1 Sep 2019 11:38:14 +0200 Subject: [PATCH] Refactor client --- httpie/cli/definition.py | 1 + httpie/client.py | 79 +++++++++++++++++++++++++--------------- httpie/sessions.py | 4 +- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/httpie/cli/definition.py b/httpie/cli/definition.py index 5a3fe6ab..201d9f7f 100644 --- a/httpie/cli/definition.py +++ b/httpie/cli/definition.py @@ -169,6 +169,7 @@ content_processing = parser.add_argument_group( content_processing.add_argument( '--compress', '-x', action='count', + default=0, help=""" Content compressed (encoded) with Deflate algorithm. The Content-Encoding header is set to deflate. diff --git a/httpie/client.py b/httpie/client.py index 49d04063..fb81667f 100644 --- a/httpie/client.py +++ b/httpie/client.py @@ -17,18 +17,13 @@ from httpie.utils import repr_dict try: - # https://urllib3.readthedocs.io/en/latest/security.html # noinspection PyPackageRequirements import urllib3 + # urllib3.disable_warnings() except (ImportError, AttributeError): - # In some rare cases, the user may have an old version of the requests - # or urllib3, and there is no method called "disable_warnings." In these - # cases, we don't need to call the method. - # They may get some noisy output but execution shouldn't die. Move on. pass - FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8' JSON_CONTENT_TYPE = 'application/json' JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*' @@ -49,9 +44,16 @@ def max_headers(limit): class HTTPieHTTPAdapter(HTTPAdapter): - def __init__(self, ssl_version=None, compress=0, **kwargs): + def __init__( + self, + ssl_version=None, + compression_enabled=False, + compress_always=False, + **kwargs, + ): self._ssl_version = ssl_version - self._compress = compress + self._compression_enabled = compression_enabled + self._compress_always = compress_always super().__init__(**kwargs) def init_poolmanager(self, *args, **kwargs): @@ -59,34 +61,50 @@ class HTTPieHTTPAdapter(HTTPAdapter): super().init_poolmanager(*args, **kwargs) def send(self, request: requests.PreparedRequest, **kwargs): - if self._compress and request.body: - self._compress_body(request, self._compress) + if request.body and self._compression_enabled: + self._compress_body(request, always=self._compress_always) return super().send(request, **kwargs) @staticmethod - def _compress_body(request: requests.PreparedRequest, compress: int): + def _compress_body(request: requests.PreparedRequest, always: bool): deflater = zlib.compressobj() - if isinstance(request.body, bytes): - deflated_data = deflater.compress(request.body) - else: - deflated_data = deflater.compress(request.body.encode()) + body_bytes = ( + request.body + if isinstance(request.body, bytes) + else request.body.encode() + ) + deflated_data = deflater.compress(body_bytes) deflated_data += deflater.flush() - if len(deflated_data) < len(request.body) or compress > 1: + is_economical = len(deflated_data) < len(body_bytes) + if is_economical or always: request.body = deflated_data request.headers['Content-Encoding'] = 'deflate' request.headers['Content-Length'] = str(len(deflated_data)) -def get_requests_session(ssl_version: str, compress: int) -> requests.Session: +def build_requests_session( + ssl_version: str, + compress_arg: int, +) -> requests.Session: requests_session = requests.Session() - adapter = HTTPieHTTPAdapter(ssl_version=ssl_version, compress=compress) - for prefix in ['http://', 'https://']: - requests_session.mount(prefix, adapter) - for cls in plugin_manager.get_transport_plugins(): - transport_plugin = cls() - requests_session.mount(prefix=transport_plugin.prefix, - adapter=transport_plugin.get_adapter()) + # Install our adapter. + adapter = HTTPieHTTPAdapter( + ssl_version=ssl_version, + compression_enabled=compress_arg > 0, + compress_always=compress_arg > 1, + ) + requests_session.mount('http://', adapter) + requests_session.mount('https://', adapter) + + # Install adapters from plugins. + for plugin_cls in plugin_manager.get_transport_plugins(): + transport_plugin = plugin_cls() + requests_session.mount( + prefix=transport_plugin.prefix, + adapter=transport_plugin.get_adapter(), + ) + return requests_session @@ -100,12 +118,15 @@ def get_response( if args.ssl_version: ssl_version = SSL_VERSION_ARG_MAPPING[args.ssl_version] - requests_session = get_requests_session(ssl_version, args.compress) + requests_session = build_requests_session( + ssl_version=ssl_version, + compress_arg=args.compress + ) requests_session.max_redirects = args.max_redirects with max_headers(args.max_headers): if not args.session and not args.session_read_only: - kwargs = get_requests_kwargs(args) + kwargs = make_requests_kwargs(args) if args.debug: dump_request(kwargs) response = requests_session.request(**kwargs) @@ -142,7 +163,7 @@ def finalize_headers(headers: RequestHeadersDict) -> RequestHeadersDict: return final_headers -def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict: +def make_default_headers(args: argparse.Namespace) -> RequestHeadersDict: default_headers = RequestHeadersDict({ 'User-Agent': DEFAULT_UA }) @@ -160,7 +181,7 @@ def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict: return default_headers -def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict: +def make_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict: """ Translate our `args` into `requests.request` keyword arguments. @@ -177,7 +198,7 @@ def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict: data = '' # Finalize headers. - headers = get_default_headers(args) + headers = make_default_headers(args) if base_headers: headers.update(base_headers) headers.update(args.headers) diff --git a/httpie/sessions.py b/httpie/sessions.py index 59ae7343..69289647 100644 --- a/httpie/sessions.py +++ b/httpie/sessions.py @@ -36,7 +36,7 @@ def get_response( aspects of the session to the request. """ - from .client import get_requests_kwargs, dump_request + from .client import make_requests_kwargs, dump_request if os.path.sep in session_name: path = os.path.expanduser(session_name) else: @@ -56,7 +56,7 @@ def get_response( session = Session(path) session.load() - kwargs = get_requests_kwargs(args, base_headers=session.headers) + kwargs = make_requests_kwargs(args, base_headers=session.headers) if args.debug: dump_request(kwargs) session.update_headers(kwargs['headers'])