From c8fd4c2d6eb34145754d5e3f1e81752cb426d041 Mon Sep 17 00:00:00 2001 From: Jakub Roztocil Date: Wed, 4 Sep 2019 00:00:03 +0200 Subject: [PATCH] Move compression out of adapter --- httpie/client.py | 66 +++++++++++++++++++----------------------------- 1 file changed, 26 insertions(+), 40 deletions(-) diff --git a/httpie/client.py b/httpie/client.py index 278b85f4..175a89b3 100644 --- a/httpie/client.py +++ b/httpie/client.py @@ -56,7 +56,6 @@ def collect_messages( send_kwargs_mergeable_from_env = make_send_kwargs_mergeable_from_env(args) requests_session = build_requests_session( ssl_version=args.ssl_version, - compress_arg=args.compress, ) if httpie_session: @@ -78,6 +77,8 @@ def collect_messages( request = requests.Request(**request_kwargs) prepared_request = requests_session.prepare_request(request) + if args.compress and prepared_request.body: + compress_body(prepared_request, always=args.compress > 1) response_count = 0 while prepared_request: yield prepared_request @@ -122,60 +123,45 @@ def max_headers(limit): http.client._MAXHEADERS = orig -class HTTPieHTTPAdapter(HTTPAdapter): +def compress_body(request: requests.PreparedRequest, always: bool): + deflater = zlib.compressobj() + body_bytes = ( + request.body + if isinstance(request.body, bytes) + else request.body.encode() + ) + deflated_data = deflater.compress(body_bytes) + deflated_data += deflater.flush() + is_economical = len(deflated_data) < len(body_bytes) + if is_economical or always: + request.body = deflated_data + request.headers['Content-Encoding'] = 'deflate' + request.headers['Content-Length'] = str(len(deflated_data)) - def __init__( - self, - ssl_version=None, - compression_enabled=False, - compress_always=False, - **kwargs, - ): + +class HTTPieHTTPSAdapter(HTTPAdapter): + + def __init__(self, ssl_version=None, **kwargs): self._ssl_version = ssl_version - self._compression_enabled = compression_enabled - self._compress_always = compress_always super().__init__(**kwargs) def init_poolmanager(self, *args, **kwargs): kwargs['ssl_version'] = self._ssl_version super().init_poolmanager(*args, **kwargs) - def send(self, request: requests.PreparedRequest, **kwargs): - if request.body and self._compression_enabled: - self._compress_body(request, always=self._compress_always) - return super().send(request, **kwargs) - - @staticmethod - def _compress_body(request: requests.PreparedRequest, always: bool): - deflater = zlib.compressobj() - body_bytes = ( - request.body - if isinstance(request.body, bytes) - else request.body.encode() - ) - deflated_data = deflater.compress(body_bytes) - deflated_data += deflater.flush() - is_economical = len(deflated_data) < len(body_bytes) - if is_economical or always: - request.body = deflated_data - request.headers['Content-Encoding'] = 'deflate' - request.headers['Content-Length'] = str(len(deflated_data)) - def build_requests_session( - compress_arg: int, ssl_version: str = None, ) -> requests.Session: requests_session = requests.Session() # Install our adapter. - adapter = HTTPieHTTPAdapter( - ssl_version=SSL_VERSION_ARG_MAPPING[ssl_version] if ssl_version else None, - compression_enabled=compress_arg > 0, - compress_always=compress_arg > 1, - ) - requests_session.mount('http://', adapter) - requests_session.mount('https://', adapter) + requests_session.mount('https://', HTTPieHTTPSAdapter( + ssl_version=( + SSL_VERSION_ARG_MAPPING[ssl_version] + if ssl_version else None + ) + )) # Install adapters from plugins. for plugin_cls in plugin_manager.get_transport_plugins():