Fix --ssl with --compress; refactor client

This commit is contained in:
Jakub Roztocil 2019-08-31 17:52:56 +02:00
parent aba3b1ec01
commit 224519e0e2

View File

@ -1,19 +1,20 @@
import argparse
import http.client
import json import json
import sys import sys
import zlib
import http.client
import requests
from contextlib import contextmanager from contextlib import contextmanager
from requests.adapters import HTTPAdapter from pathlib import Path
from requests.structures import CaseInsensitiveDict
from httpie import sessions import requests
from httpie import __version__ from requests.adapters import HTTPAdapter
from httpie import __version__, sessions
from httpie.cli.constants import SSL_VERSION_ARG_MAPPING from httpie.cli.constants import SSL_VERSION_ARG_MAPPING
from httpie.cli.dicts import RequestHeadersDict
from httpie.plugins import plugin_manager from httpie.plugins import plugin_manager
from httpie.utils import repr_dict_nice from httpie.utils import repr_dict_nice
import zlib
try: try:
# https://urllib3.readthedocs.io/en/latest/security.html # https://urllib3.readthedocs.io/en/latest/security.html
@ -31,7 +32,7 @@ except (ImportError, AttributeError):
FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8' FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8'
JSON_CONTENT_TYPE = 'application/json' JSON_CONTENT_TYPE = 'application/json'
JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*' JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*'
DEFAULT_UA = 'HTTPie/%s' % __version__ DEFAULT_UA = f'HTTPie/{__version__}'
# noinspection PyProtectedMember # noinspection PyProtectedMember
@ -48,46 +49,40 @@ def max_headers(limit):
class HTTPieHTTPAdapter(HTTPAdapter): class HTTPieHTTPAdapter(HTTPAdapter):
def __init__(self, ssl_version=None, **kwargs): def __init__(self, ssl_version=None, compress=0, **kwargs):
self._ssl_version = ssl_version self._ssl_version = ssl_version
self._compress = compress
super().__init__(**kwargs) super().__init__(**kwargs)
def init_poolmanager(self, *args, **kwargs): def init_poolmanager(self, *args, **kwargs):
kwargs['ssl_version'] = self._ssl_version kwargs['ssl_version'] = self._ssl_version
super().init_poolmanager(*args, **kwargs) super().init_poolmanager(*args, **kwargs)
def send(self, request: requests.PreparedRequest, **kwargs):
class ContentCompressionHttpAdapter(HTTPAdapter): if self._compress and request.body:
self._compress_body(request, self._compress)
def __init__(self, compress, **kwargs):
self.compress = compress
super().__init__(**kwargs)
def send(self, request, **kwargs):
if request.body and self.compress > 0:
deflater = zlib.compressobj()
if isinstance(request.body, bytes):
deflated_data = deflater.compress(request.body)
else:
deflated_data = deflater.compress(request.body.encode())
deflated_data += deflater.flush()
if len(deflated_data) < len(request.body) or self.compress > 1:
request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data))
return super().send(request, **kwargs) return super().send(request, **kwargs)
@staticmethod
def _compress_body(request: requests.PreparedRequest, compress: int):
deflater = zlib.compressobj()
if isinstance(request.body, bytes):
deflated_data = deflater.compress(request.body)
else:
deflated_data = deflater.compress(request.body.encode())
deflated_data += deflater.flush()
if len(deflated_data) < len(request.body) or compress > 1:
request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data))
def get_requests_session(ssl_version, compress):
def get_requests_session(ssl_version: str, compress: int) -> requests.Session:
requests_session = requests.Session() requests_session = requests.Session()
requests_session.mount( adapter = HTTPieHTTPAdapter(ssl_version=ssl_version, compress=compress)
'https://', for prefix in ['http://', 'https://']:
HTTPieHTTPAdapter(ssl_version=ssl_version) requests_session.mount(prefix, adapter)
)
if compress:
adapter = ContentCompressionHttpAdapter(compress)
for prefix in ['http://', 'https://']:
requests_session.mount(prefix, adapter)
for cls in plugin_manager.get_transport_plugins(): for cls in plugin_manager.get_transport_plugins():
transport_plugin = cls() transport_plugin = cls()
requests_session.mount(prefix=transport_plugin.prefix, requests_session.mount(prefix=transport_plugin.prefix,
@ -95,7 +90,10 @@ def get_requests_session(ssl_version, compress):
return requests_session return requests_session
def get_response(args, config_dir): def get_response(
args: argparse.Namespace,
config_dir: Path
) -> requests.Response:
"""Send the request and return a `request.Response`.""" """Send the request and return a `request.Response`."""
ssl_version = None ssl_version = None
@ -123,32 +121,29 @@ def get_response(args, config_dir):
return response return response
def dump_request(kwargs): def dump_request(kwargs: dict):
sys.stderr.write('\n>>> requests.request(**%s)\n\n' sys.stderr.write('\n>>> requests.request(**%s)\n\n'
% repr_dict_nice(kwargs)) % repr_dict_nice(kwargs))
def finalize_headers(headers): def finalize_headers(headers: RequestHeadersDict) -> RequestHeadersDict:
final_headers = {} final_headers = RequestHeadersDict()
for name, value in headers.items(): for name, value in headers.items():
if value is not None: if value is not None:
# >leading or trailing LWS MAY be removed without # >leading or trailing LWS MAY be removed without
# >changing the semantics of the field value" # >changing the semantics of the field value"
# -https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html # -https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
# Also, requests raises `InvalidHeader` for leading spaces. # Also, requests raises `InvalidHeader` for leading spaces.
value = value.strip() value = value.strip()
if isinstance(value, str): if isinstance(value, str):
# See: https://github.com/jakubroztocil/httpie/issues/212 # See: https://github.com/jakubroztocil/httpie/issues/212
value = value.encode('utf8') value = value.encode('utf8')
final_headers[name] = value final_headers[name] = value
return final_headers return final_headers
def get_default_headers(args): def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict:
default_headers = CaseInsensitiveDict({ default_headers = RequestHeadersDict({
'User-Agent': DEFAULT_UA 'User-Agent': DEFAULT_UA
}) })
@ -165,7 +160,7 @@ def get_default_headers(args):
return default_headers return default_headers
def get_requests_kwargs(args, base_headers=None): def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict:
""" """
Translate our `args` into `requests.request` keyword arguments. Translate our `args` into `requests.request` keyword arguments.