mirror of
https://github.com/httpie/cli.git
synced 2024-11-22 15:53:13 +01:00
Fix --ssl with --compress; refactor client
This commit is contained in:
parent
aba3b1ec01
commit
224519e0e2
@ -1,19 +1,20 @@
|
|||||||
|
import argparse
|
||||||
|
import http.client
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
import zlib
|
||||||
import http.client
|
|
||||||
import requests
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from requests.adapters import HTTPAdapter
|
from pathlib import Path
|
||||||
from requests.structures import CaseInsensitiveDict
|
|
||||||
|
|
||||||
from httpie import sessions
|
import requests
|
||||||
from httpie import __version__
|
from requests.adapters import HTTPAdapter
|
||||||
|
|
||||||
|
from httpie import __version__, sessions
|
||||||
from httpie.cli.constants import SSL_VERSION_ARG_MAPPING
|
from httpie.cli.constants import SSL_VERSION_ARG_MAPPING
|
||||||
|
from httpie.cli.dicts import RequestHeadersDict
|
||||||
from httpie.plugins import plugin_manager
|
from httpie.plugins import plugin_manager
|
||||||
from httpie.utils import repr_dict_nice
|
from httpie.utils import repr_dict_nice
|
||||||
|
|
||||||
import zlib
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# https://urllib3.readthedocs.io/en/latest/security.html
|
# https://urllib3.readthedocs.io/en/latest/security.html
|
||||||
@ -31,7 +32,7 @@ except (ImportError, AttributeError):
|
|||||||
FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8'
|
FORM_CONTENT_TYPE = 'application/x-www-form-urlencoded; charset=utf-8'
|
||||||
JSON_CONTENT_TYPE = 'application/json'
|
JSON_CONTENT_TYPE = 'application/json'
|
||||||
JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*'
|
JSON_ACCEPT = f'{JSON_CONTENT_TYPE}, */*'
|
||||||
DEFAULT_UA = 'HTTPie/%s' % __version__
|
DEFAULT_UA = f'HTTPie/{__version__}'
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
@ -48,46 +49,40 @@ def max_headers(limit):
|
|||||||
|
|
||||||
class HTTPieHTTPAdapter(HTTPAdapter):
|
class HTTPieHTTPAdapter(HTTPAdapter):
|
||||||
|
|
||||||
def __init__(self, ssl_version=None, **kwargs):
|
def __init__(self, ssl_version=None, compress=0, **kwargs):
|
||||||
self._ssl_version = ssl_version
|
self._ssl_version = ssl_version
|
||||||
|
self._compress = compress
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def init_poolmanager(self, *args, **kwargs):
|
def init_poolmanager(self, *args, **kwargs):
|
||||||
kwargs['ssl_version'] = self._ssl_version
|
kwargs['ssl_version'] = self._ssl_version
|
||||||
super().init_poolmanager(*args, **kwargs)
|
super().init_poolmanager(*args, **kwargs)
|
||||||
|
|
||||||
|
def send(self, request: requests.PreparedRequest, **kwargs):
|
||||||
|
if self._compress and request.body:
|
||||||
|
self._compress_body(request, self._compress)
|
||||||
|
return super().send(request, **kwargs)
|
||||||
|
|
||||||
class ContentCompressionHttpAdapter(HTTPAdapter):
|
@staticmethod
|
||||||
|
def _compress_body(request: requests.PreparedRequest, compress: int):
|
||||||
def __init__(self, compress, **kwargs):
|
|
||||||
self.compress = compress
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def send(self, request, **kwargs):
|
|
||||||
if request.body and self.compress > 0:
|
|
||||||
deflater = zlib.compressobj()
|
deflater = zlib.compressobj()
|
||||||
if isinstance(request.body, bytes):
|
if isinstance(request.body, bytes):
|
||||||
deflated_data = deflater.compress(request.body)
|
deflated_data = deflater.compress(request.body)
|
||||||
else:
|
else:
|
||||||
deflated_data = deflater.compress(request.body.encode())
|
deflated_data = deflater.compress(request.body.encode())
|
||||||
deflated_data += deflater.flush()
|
deflated_data += deflater.flush()
|
||||||
if len(deflated_data) < len(request.body) or self.compress > 1:
|
if len(deflated_data) < len(request.body) or compress > 1:
|
||||||
request.body = deflated_data
|
request.body = deflated_data
|
||||||
request.headers['Content-Encoding'] = 'deflate'
|
request.headers['Content-Encoding'] = 'deflate'
|
||||||
request.headers['Content-Length'] = str(len(deflated_data))
|
request.headers['Content-Length'] = str(len(deflated_data))
|
||||||
return super().send(request, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_requests_session(ssl_version, compress):
|
def get_requests_session(ssl_version: str, compress: int) -> requests.Session:
|
||||||
requests_session = requests.Session()
|
requests_session = requests.Session()
|
||||||
requests_session.mount(
|
adapter = HTTPieHTTPAdapter(ssl_version=ssl_version, compress=compress)
|
||||||
'https://',
|
|
||||||
HTTPieHTTPAdapter(ssl_version=ssl_version)
|
|
||||||
)
|
|
||||||
if compress:
|
|
||||||
adapter = ContentCompressionHttpAdapter(compress)
|
|
||||||
for prefix in ['http://', 'https://']:
|
for prefix in ['http://', 'https://']:
|
||||||
requests_session.mount(prefix, adapter)
|
requests_session.mount(prefix, adapter)
|
||||||
|
|
||||||
for cls in plugin_manager.get_transport_plugins():
|
for cls in plugin_manager.get_transport_plugins():
|
||||||
transport_plugin = cls()
|
transport_plugin = cls()
|
||||||
requests_session.mount(prefix=transport_plugin.prefix,
|
requests_session.mount(prefix=transport_plugin.prefix,
|
||||||
@ -95,7 +90,10 @@ def get_requests_session(ssl_version, compress):
|
|||||||
return requests_session
|
return requests_session
|
||||||
|
|
||||||
|
|
||||||
def get_response(args, config_dir):
|
def get_response(
|
||||||
|
args: argparse.Namespace,
|
||||||
|
config_dir: Path
|
||||||
|
) -> requests.Response:
|
||||||
"""Send the request and return a `request.Response`."""
|
"""Send the request and return a `request.Response`."""
|
||||||
|
|
||||||
ssl_version = None
|
ssl_version = None
|
||||||
@ -123,32 +121,29 @@ def get_response(args, config_dir):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def dump_request(kwargs):
|
def dump_request(kwargs: dict):
|
||||||
sys.stderr.write('\n>>> requests.request(**%s)\n\n'
|
sys.stderr.write('\n>>> requests.request(**%s)\n\n'
|
||||||
% repr_dict_nice(kwargs))
|
% repr_dict_nice(kwargs))
|
||||||
|
|
||||||
|
|
||||||
def finalize_headers(headers):
|
def finalize_headers(headers: RequestHeadersDict) -> RequestHeadersDict:
|
||||||
final_headers = {}
|
final_headers = RequestHeadersDict()
|
||||||
for name, value in headers.items():
|
for name, value in headers.items():
|
||||||
if value is not None:
|
if value is not None:
|
||||||
|
|
||||||
# >leading or trailing LWS MAY be removed without
|
# >leading or trailing LWS MAY be removed without
|
||||||
# >changing the semantics of the field value"
|
# >changing the semantics of the field value"
|
||||||
# -https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
|
# -https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
|
||||||
# Also, requests raises `InvalidHeader` for leading spaces.
|
# Also, requests raises `InvalidHeader` for leading spaces.
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
|
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
# See: https://github.com/jakubroztocil/httpie/issues/212
|
# See: https://github.com/jakubroztocil/httpie/issues/212
|
||||||
value = value.encode('utf8')
|
value = value.encode('utf8')
|
||||||
|
|
||||||
final_headers[name] = value
|
final_headers[name] = value
|
||||||
return final_headers
|
return final_headers
|
||||||
|
|
||||||
|
|
||||||
def get_default_headers(args):
|
def get_default_headers(args: argparse.Namespace) -> RequestHeadersDict:
|
||||||
default_headers = CaseInsensitiveDict({
|
default_headers = RequestHeadersDict({
|
||||||
'User-Agent': DEFAULT_UA
|
'User-Agent': DEFAULT_UA
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -165,7 +160,7 @@ def get_default_headers(args):
|
|||||||
return default_headers
|
return default_headers
|
||||||
|
|
||||||
|
|
||||||
def get_requests_kwargs(args, base_headers=None):
|
def get_requests_kwargs(args: argparse.Namespace, base_headers=None) -> dict:
|
||||||
"""
|
"""
|
||||||
Translate our `args` into `requests.request` keyword arguments.
|
Translate our `args` into `requests.request` keyword arguments.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user