Move compression out of adapter

This commit is contained in:
Jakub Roztocil 2019-09-04 00:00:03 +02:00
parent 99f8a8c23d
commit c8fd4c2d6e

View File

@ -56,7 +56,6 @@ def collect_messages(
send_kwargs_mergeable_from_env = make_send_kwargs_mergeable_from_env(args) send_kwargs_mergeable_from_env = make_send_kwargs_mergeable_from_env(args)
requests_session = build_requests_session( requests_session = build_requests_session(
ssl_version=args.ssl_version, ssl_version=args.ssl_version,
compress_arg=args.compress,
) )
if httpie_session: if httpie_session:
@ -78,6 +77,8 @@ def collect_messages(
request = requests.Request(**request_kwargs) request = requests.Request(**request_kwargs)
prepared_request = requests_session.prepare_request(request) prepared_request = requests_session.prepare_request(request)
if args.compress and prepared_request.body:
compress_body(prepared_request, always=args.compress > 1)
response_count = 0 response_count = 0
while prepared_request: while prepared_request:
yield prepared_request yield prepared_request
@ -122,60 +123,45 @@ def max_headers(limit):
http.client._MAXHEADERS = orig http.client._MAXHEADERS = orig
class HTTPieHTTPAdapter(HTTPAdapter): def compress_body(request: requests.PreparedRequest, always: bool):
deflater = zlib.compressobj()
body_bytes = (
request.body
if isinstance(request.body, bytes)
else request.body.encode()
)
deflated_data = deflater.compress(body_bytes)
deflated_data += deflater.flush()
is_economical = len(deflated_data) < len(body_bytes)
if is_economical or always:
request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data))
def __init__(
self, class HTTPieHTTPSAdapter(HTTPAdapter):
ssl_version=None,
compression_enabled=False, def __init__(self, ssl_version=None, **kwargs):
compress_always=False,
**kwargs,
):
self._ssl_version = ssl_version self._ssl_version = ssl_version
self._compression_enabled = compression_enabled
self._compress_always = compress_always
super().__init__(**kwargs) super().__init__(**kwargs)
def init_poolmanager(self, *args, **kwargs): def init_poolmanager(self, *args, **kwargs):
kwargs['ssl_version'] = self._ssl_version kwargs['ssl_version'] = self._ssl_version
super().init_poolmanager(*args, **kwargs) super().init_poolmanager(*args, **kwargs)
def send(self, request: requests.PreparedRequest, **kwargs):
if request.body and self._compression_enabled:
self._compress_body(request, always=self._compress_always)
return super().send(request, **kwargs)
@staticmethod
def _compress_body(request: requests.PreparedRequest, always: bool):
deflater = zlib.compressobj()
body_bytes = (
request.body
if isinstance(request.body, bytes)
else request.body.encode()
)
deflated_data = deflater.compress(body_bytes)
deflated_data += deflater.flush()
is_economical = len(deflated_data) < len(body_bytes)
if is_economical or always:
request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data))
def build_requests_session( def build_requests_session(
compress_arg: int,
ssl_version: str = None, ssl_version: str = None,
) -> requests.Session: ) -> requests.Session:
requests_session = requests.Session() requests_session = requests.Session()
# Install our adapter. # Install our adapter.
adapter = HTTPieHTTPAdapter( requests_session.mount('https://', HTTPieHTTPSAdapter(
ssl_version=SSL_VERSION_ARG_MAPPING[ssl_version] if ssl_version else None, ssl_version=(
compression_enabled=compress_arg > 0, SSL_VERSION_ARG_MAPPING[ssl_version]
compress_always=compress_arg > 1, if ssl_version else None
) )
requests_session.mount('http://', adapter) ))
requests_session.mount('https://', adapter)
# Install adapters from plugins. # Install adapters from plugins.
for plugin_cls in plugin_manager.get_transport_plugins(): for plugin_cls in plugin_manager.get_transport_plugins():