Python 3 annotations, super(), pathlib, etc.

This commit is contained in:
Jakub Roztocil 2019-08-30 11:32:14 +02:00
parent 63df735fef
commit 0f654388fc
23 changed files with 229 additions and 196 deletions

View File

@ -2,12 +2,15 @@
HTTPie - a CLI, cURL-like tool for humans. HTTPie - a CLI, cURL-like tool for humans.
""" """
from enum import Enum
__version__ = '2.0.0-dev' __version__ = '2.0.0-dev'
__author__ = 'Jakub Roztocil' __author__ = 'Jakub Roztocil'
__licence__ = 'BSD' __licence__ = 'BSD'
class ExitStatus: class ExitStatus(Enum):
"""Program exit code constants.""" """Program exit code constants."""
SUCCESS = 0 SUCCESS = 0
ERROR = 1 ERROR = 1
@ -23,10 +26,3 @@ class ExitStatus:
ERROR_HTTP_3XX = 3 ERROR_HTTP_3XX = 3
ERROR_HTTP_4XX = 4 ERROR_HTTP_4XX = 4
ERROR_HTTP_5XX = 5 ERROR_HTTP_5XX = 5
EXIT_STATUS_LABELS = {
value: key
for key, value in ExitStatus.__dict__.items()
if key.isupper()
}

View File

@ -8,10 +8,12 @@ import sys
def main(): def main():
try: try:
from .core import main from .core import main
sys.exit(main()) exit_status = main()
except KeyboardInterrupt: except KeyboardInterrupt:
from . import ExitStatus from . import ExitStatus
sys.exit(ExitStatus.ERROR_CTRL_C) exit_status = ExitStatus.ERROR_CTRL_C
sys.exit(exit_status.value)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,9 +1,7 @@
"""CLI arguments definition. """
CLI arguments definition.
NOTE: the CLI interface may change before reaching v1.0.
""" """
# noinspection PyCompatibility
from argparse import ( from argparse import (
RawDescriptionHelpFormatter, FileType, RawDescriptionHelpFormatter, FileType,
OPTIONAL, ZERO_OR_MORE, SUPPRESS OPTIONAL, ZERO_OR_MORE, SUPPRESS
@ -40,7 +38,7 @@ class HTTPieHelpFormatter(RawDescriptionHelpFormatter):
def __init__(self, max_help_position=6, *args, **kwargs): def __init__(self, max_help_position=6, *args, **kwargs):
# A smaller indent for args help. # A smaller indent for args help.
kwargs['max_help_position'] = max_help_position kwargs['max_help_position'] = max_help_position
super(HTTPieHelpFormatter, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _split_lines(self, text, width): def _split_lines(self, text, width):
text = dedent(text).strip() + '\n\n' text = dedent(text).strip() + '\n\n'
@ -457,7 +455,7 @@ auth.add_argument(
) )
class _AuthTypeLazyChoices(object): class _AuthTypeLazyChoices:
# Needed for plugin testing # Needed for plugin testing
def __contains__(self, item): def __contains__(self, item):

View File

@ -50,18 +50,18 @@ class HTTPieHTTPAdapter(HTTPAdapter):
def __init__(self, ssl_version=None, **kwargs): def __init__(self, ssl_version=None, **kwargs):
self._ssl_version = ssl_version self._ssl_version = ssl_version
super(HTTPieHTTPAdapter, self).__init__(**kwargs) super().__init__(**kwargs)
def init_poolmanager(self, *args, **kwargs): def init_poolmanager(self, *args, **kwargs):
kwargs['ssl_version'] = self._ssl_version kwargs['ssl_version'] = self._ssl_version
super(HTTPieHTTPAdapter, self).init_poolmanager(*args, **kwargs) super().init_poolmanager(*args, **kwargs)
class ContentCompressionHttpAdapter(HTTPAdapter): class ContentCompressionHttpAdapter(HTTPAdapter):
def __init__(self, compress, **kwargs): def __init__(self, compress, **kwargs):
self.compress = compress self.compress = compress
super(ContentCompressionHttpAdapter, self).__init__(**kwargs) super().__init__(**kwargs)
def send(self, request, **kwargs): def send(self, request, **kwargs):
if request.body and self.compress > 0: if request.body and self.compress > 0:
@ -75,7 +75,7 @@ class ContentCompressionHttpAdapter(HTTPAdapter):
request.body = deflated_data request.body = deflated_data
request.headers['Content-Encoding'] = 'deflate' request.headers['Content-Encoding'] = 'deflate'
request.headers['Content-Length'] = str(len(deflated_data)) request.headers['Content-Length'] = str(len(deflated_data))
return super(ContentCompressionHttpAdapter, self).send(request, **kwargs) return super().send(request, **kwargs)
def get_requests_session(ssl_version, compress): def get_requests_session(ssl_version, compress):

View File

@ -1,12 +1,14 @@
import os
import json
import errno import errno
import json
import os
from pathlib import Path
from typing import Union
from httpie import __version__ from httpie import __version__
from httpie.compat import is_windows from httpie.compat import is_windows
DEFAULT_CONFIG_DIR = str(os.environ.get( DEFAULT_CONFIG_DIR = Path(os.environ.get(
'HTTPIE_CONFIG_DIR', 'HTTPIE_CONFIG_DIR',
os.path.expanduser('~/.httpie') if not is_windows else os.path.expanduser('~/.httpie') if not is_windows else
os.path.expandvars(r'%APPDATA%\\httpie') os.path.expandvars(r'%APPDATA%\\httpie')
@ -14,41 +16,36 @@ DEFAULT_CONFIG_DIR = str(os.environ.get(
class BaseConfigDict(dict): class BaseConfigDict(dict):
name = None name = None
helpurl = None helpurl = None
about = None about = None
def __getattr__(self, item): def _get_path(self) -> Path:
return self[item]
def _get_path(self):
"""Return the config file path without side-effects.""" """Return the config file path without side-effects."""
raise NotImplementedError() raise NotImplementedError()
@property def path(self) -> Path:
def path(self):
"""Return the config file path creating basedir, if needed.""" """Return the config file path creating basedir, if needed."""
path = self._get_path() path = self._get_path()
try: try:
os.makedirs(os.path.dirname(path), mode=0o700) path.parent.mkdir(mode=0o700, parents=True)
except OSError as e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
return path return path
def is_new(self): def is_new(self) -> bool:
return not os.path.exists(self._get_path()) return not self._get_path().exists()
def load(self): def load(self):
try: try:
with open(self.path, 'rt') as f: with self.path().open('rt') as f:
try: try:
data = json.load(f) data = json.load(f)
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
'Invalid %s JSON: %s [%s]' % 'Invalid %s JSON: %s [%s]' %
(type(self).__name__, str(e), self.path) (type(self).__name__, str(e), self.path())
) )
self.update(data) self.update(data)
except IOError as e: except IOError as e:
@ -66,7 +63,7 @@ class BaseConfigDict(dict):
self['__meta__']['about'] = self.about self['__meta__']['about'] = self.about
try: try:
with open(self.path, 'w') as f: with self.path().open('w') as f:
json.dump(self, f, indent=4, sort_keys=True, ensure_ascii=True) json.dump(self, f, indent=4, sort_keys=True, ensure_ascii=True)
f.write('\n') f.write('\n')
except IOError: except IOError:
@ -75,26 +72,28 @@ class BaseConfigDict(dict):
def delete(self): def delete(self):
try: try:
os.unlink(self.path) self.path().unlink()
except OSError as e: except OSError as e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
class Config(BaseConfigDict): class Config(BaseConfigDict):
name = 'config' name = 'config'
helpurl = 'https://httpie.org/doc#config' helpurl = 'https://httpie.org/doc#config'
about = 'HTTPie configuration file' about = 'HTTPie configuration file'
DEFAULTS = { DEFAULTS = {
'default_options': [] 'default_options': []
} }
def __init__(self, directory=DEFAULT_CONFIG_DIR): def __init__(self, directory: Union[str, Path] = DEFAULT_CONFIG_DIR):
super(Config, self).__init__() super().__init__()
self.update(self.DEFAULTS) self.update(self.DEFAULTS)
self.directory = directory self.directory = Path(directory)
def _get_path(self): def _get_path(self) -> Path:
return os.path.join(self.directory, self.name + '.json') return self.directory / (self.name + '.json')
@property
def default_options(self) -> list:
return self['default_options']

View File

@ -1,4 +1,5 @@
import sys import sys
from pathlib import Path
from typing import Union, IO, Optional from typing import Union, IO, Optional
@ -13,7 +14,7 @@ from httpie.config import DEFAULT_CONFIG_DIR, Config
from httpie.utils import repr_dict_nice from httpie.utils import repr_dict_nice
class Environment(object): class Environment:
""" """
Information about the execution context Information about the execution context
(standard streams, config directory, etc). (standard streams, config directory, etc).
@ -23,16 +24,16 @@ class Environment(object):
is used by the test suite to simulate various scenarios. is used by the test suite to simulate various scenarios.
""" """
is_windows = is_windows is_windows: bool = is_windows
config_dir = DEFAULT_CONFIG_DIR config_dir: Path = DEFAULT_CONFIG_DIR
stdin: Optional[IO] = sys.stdin # `None` when closed fd (#791) stdin: Optional[IO] = sys.stdin # `None` when closed fd (#791)
stdin_isatty = stdin.isatty() if stdin else False stdin_isatty: bool = stdin.isatty() if stdin else False
stdin_encoding = None stdin_encoding: str = None
stdout = sys.stdout stdout: IO = sys.stdout
stdout_isatty = stdout.isatty() stdout_isatty: bool = stdout.isatty()
stdout_encoding = None stdout_encoding: str = None
stderr = sys.stderr stderr: IO = sys.stderr
stderr_isatty = stderr.isatty() stderr_isatty: bool = stderr.isatty()
colors = 256 colors = 256
if not is_windows: if not is_windows:
if curses: if curses:
@ -73,12 +74,13 @@ class Environment(object):
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from colorama import AnsiToWin32 from colorama import AnsiToWin32
if isinstance(self.stdout, AnsiToWin32): if isinstance(self.stdout, AnsiToWin32):
# noinspection PyUnresolvedReferences
actual_stdout = self.stdout.wrapped actual_stdout = self.stdout.wrapped
self.stdout_encoding = getattr( self.stdout_encoding = getattr(
actual_stdout, 'encoding', None) or 'utf8' actual_stdout, 'encoding', None) or 'utf8'
@property @property
def config(self): def config(self) -> Config:
if not hasattr(self, '_config'): if not hasattr(self, '_config'):
self._config = Config(directory=self.config_dir) self._config = Config(directory=self.config_dir)
if self._config.is_new(): if self._config.is_new():

View File

@ -10,27 +10,29 @@ Invocation flow:
5. Exit. 5. Exit.
""" """
import sys import argparse
import errno import errno
import platform import platform
import sys
from typing import Callable, List, Union
import requests import requests
from requests import __version__ as requests_version
from pygments import __version__ as pygments_version from pygments import __version__ as pygments_version
from requests import __version__ as requests_version
from httpie import __version__ as httpie_version, ExitStatus from httpie import ExitStatus, __version__ as httpie_version
from httpie.client import get_response from httpie.client import get_response
from httpie.downloads import Downloader
from httpie.context import Environment from httpie.context import Environment
from httpie.plugins import plugin_manager from httpie.downloads import Downloader
from httpie.output.streams import ( from httpie.output.streams import (
build_output_stream, build_output_stream,
write_stream, write_stream,
write_stream_with_colors_win_py3 write_stream_with_colors_win_py3,
) )
from httpie.plugins import plugin_manager
def get_exit_status(http_status, follow=False): def get_exit_status(http_status: int, follow=False) -> ExitStatus:
"""Translate HTTP status code to exit status code.""" """Translate HTTP status code to exit status code."""
if 300 <= http_status <= 399 and not follow: if 300 <= http_status <= 399 and not follow:
# Redirect # Redirect
@ -45,7 +47,7 @@ def get_exit_status(http_status, follow=False):
return ExitStatus.SUCCESS return ExitStatus.SUCCESS
def print_debug_info(env): def print_debug_info(env: Environment):
env.stderr.writelines([ env.stderr.writelines([
'HTTPie %s\n' % httpie_version, 'HTTPie %s\n' % httpie_version,
'Requests %s\n' % requests_version, 'Requests %s\n' % requests_version,
@ -58,7 +60,10 @@ def print_debug_info(env):
env.stderr.write('\n') env.stderr.write('\n')
def decode_args(args, stdin_encoding): def decode_args(
args: List[Union[str, bytes]],
stdin_encoding: str
) -> List[str]:
""" """
Convert all bytes args to str Convert all bytes args to str
by decoding them using stdin encoding. by decoding them using stdin encoding.
@ -71,7 +76,11 @@ def decode_args(args, stdin_encoding):
] ]
def program(args, env, log_error): def program(
args: argparse.Namespace,
env: Environment,
log_error: Callable
) -> ExitStatus:
""" """
The main program without error handling The main program without error handling
@ -168,7 +177,11 @@ def program(args, env, log_error):
args.output_file.close() args.output_file.close()
def main(args=sys.argv, env=Environment(), custom_log_error=None): def main(
args: List[Union[str, bytes]] = sys.argv,
env=Environment(),
custom_log_error: Callable = None
) -> ExitStatus:
""" """
The main function. The main function.
@ -218,7 +231,7 @@ def main(args=sys.argv, env=Environment(), custom_log_error=None):
raise raise
exit_status = ExitStatus.ERROR_CTRL_C exit_status = ExitStatus.ERROR_CTRL_C
except SystemExit as e: except SystemExit as e:
if e.code != ExitStatus.SUCCESS: if e.code != ExitStatus.SUCCESS.value:
env.stderr.write('\n') env.stderr.write('\n')
if include_traceback: if include_traceback:
raise raise
@ -236,7 +249,7 @@ def main(args=sys.argv, env=Environment(), custom_log_error=None):
raise raise
exit_status = ExitStatus.ERROR_CTRL_C exit_status = ExitStatus.ERROR_CTRL_C
except SystemExit as e: except SystemExit as e:
if e.code != ExitStatus.SUCCESS: if e.code != ExitStatus.SUCCESS.value:
env.stderr.write('\n') env.stderr.write('\n')
if include_traceback: if include_traceback:
raise raise

View File

@ -4,24 +4,27 @@ Download mode implementation.
""" """
from __future__ import division from __future__ import division
import errno
import mimetypes
import os import os
import re import re
import sys import sys
import errno
import mimetypes
import threading import threading
from time import sleep, time
from mailbox import Message from mailbox import Message
from time import sleep, time
from typing import IO, Optional, Tuple
from urllib.parse import urlsplit from urllib.parse import urlsplit
from httpie.output.streams import RawStream import requests
from httpie.models import HTTPResponse from httpie.models import HTTPResponse
from httpie.output.streams import RawStream
from httpie.utils import humanize_bytes from httpie.utils import humanize_bytes
PARTIAL_CONTENT = 206 PARTIAL_CONTENT = 206
CLEAR_LINE = '\r\033[K' CLEAR_LINE = '\r\033[K'
PROGRESS = ( PROGRESS = (
'{percentage: 6.2f} %' '{percentage: 6.2f} %'
@ -38,7 +41,7 @@ class ContentRangeError(ValueError):
pass pass
def parse_content_range(content_range, resumed_from): def parse_content_range(content_range: str, resumed_from: int) -> int:
""" """
Parse and validate Content-Range header. Parse and validate Content-Range header.
@ -97,7 +100,9 @@ def parse_content_range(content_range, resumed_from):
return last_byte_pos + 1 return last_byte_pos + 1
def filename_from_content_disposition(content_disposition): def filename_from_content_disposition(
content_disposition: str
) -> Optional[str]:
""" """
Extract and validate filename from a Content-Disposition header. Extract and validate filename from a Content-Disposition header.
@ -116,7 +121,7 @@ def filename_from_content_disposition(content_disposition):
return filename return filename
def filename_from_url(url, content_type): def filename_from_url(url: str, content_type: str) -> str:
fn = urlsplit(url).path.rstrip('/') fn = urlsplit(url).path.rstrip('/')
fn = os.path.basename(fn) if fn else 'index' fn = os.path.basename(fn) if fn else 'index'
if '.' not in fn and content_type: if '.' not in fn and content_type:
@ -136,7 +141,7 @@ def filename_from_url(url, content_type):
return fn return fn
def trim_filename(filename, max_len): def trim_filename(filename: str, max_len: int) -> str:
if len(filename) > max_len: if len(filename) > max_len:
trim_by = len(filename) - max_len trim_by = len(filename) - max_len
name, ext = os.path.splitext(filename) name, ext = os.path.splitext(filename)
@ -147,7 +152,7 @@ def trim_filename(filename, max_len):
return filename return filename
def get_filename_max_length(directory): def get_filename_max_length(directory: str) -> int:
max_len = 255 max_len = 255
try: try:
pathconf = os.pathconf pathconf = os.pathconf
@ -162,14 +167,14 @@ def get_filename_max_length(directory):
return max_len return max_len
def trim_filename_if_needed(filename, directory='.', extra=0): def trim_filename_if_needed(filename: str, directory='.', extra=0) -> str:
max_len = get_filename_max_length(directory) - extra max_len = get_filename_max_length(directory) - extra
if len(filename) > max_len: if len(filename) > max_len:
filename = trim_filename(filename, max_len) filename = trim_filename(filename, max_len)
return filename return filename
def get_unique_filename(filename, exists=os.path.exists): def get_unique_filename(filename: str, exists=os.path.exists) -> str:
attempt = 0 attempt = 0
while True: while True:
suffix = '-' + str(attempt) if attempt > 0 else '' suffix = '-' + str(attempt) if attempt > 0 else ''
@ -180,10 +185,14 @@ def get_unique_filename(filename, exists=os.path.exists):
attempt += 1 attempt += 1
class Downloader(object): class Downloader:
def __init__(self, output_file=None, def __init__(
resume=False, progress_file=sys.stderr): self,
output_file: IO = None,
resume: bool = False,
progress_file: IO = sys.stderr
):
""" """
:param resume: Should the download resume if partial download :param resume: Should the download resume if partial download
already exists. already exists.
@ -195,24 +204,21 @@ class Downloader(object):
:param progress_file: Where to report download progress. :param progress_file: Where to report download progress.
""" """
self.finished = False
self.status = DownloadStatus()
self._output_file = output_file self._output_file = output_file
self._resume = resume self._resume = resume
self._resumed_from = 0 self._resumed_from = 0
self.finished = False
self.status = Status()
self._progress_reporter = ProgressReporterThread( self._progress_reporter = ProgressReporterThread(
status=self.status, status=self.status,
output=progress_file output=progress_file
) )
def pre_request(self, request_headers): def pre_request(self, request_headers: dict):
"""Called just before the HTTP request is sent. """Called just before the HTTP request is sent.
Might alter `request_headers`. Might alter `request_headers`.
:type request_headers: dict
""" """
# Ask the server not to encode the content so that we can resume, etc. # Ask the server not to encode the content so that we can resume, etc.
request_headers['Accept-Encoding'] = 'identity' request_headers['Accept-Encoding'] = 'identity'
@ -224,13 +230,12 @@ class Downloader(object):
request_headers['Range'] = 'bytes=%d-' % bytes_have request_headers['Range'] = 'bytes=%d-' % bytes_have
self._resumed_from = bytes_have self._resumed_from = bytes_have
def start(self, final_response): def start(self, final_response: requests.Response) -> Tuple[RawStream, IO]:
""" """
Initiate and return a stream for `response` body with progress Initiate and return a stream for `response` body with progress
callback attached. Can be called only once. callback attached. Can be called only once.
:param final_response: Initiated response object with headers already fetched :param final_response: Initiated response object with headers already fetched
:type final_response: requests.models.Response
:return: RawStream, output_file :return: RawStream, output_file
@ -297,14 +302,14 @@ class Downloader(object):
self._progress_reporter.stop() self._progress_reporter.stop()
@property @property
def interrupted(self): def interrupted(self) -> bool:
return ( return (
self.finished self.finished
and self.status.total_size and self.status.total_size
and self.status.total_size != self.status.downloaded and self.status.total_size != self.status.downloaded
) )
def chunk_downloaded(self, chunk): def chunk_downloaded(self, chunk: bytes):
""" """
A download progress callback. A download progress callback.
@ -316,7 +321,9 @@ class Downloader(object):
self.status.chunk_downloaded(len(chunk)) self.status.chunk_downloaded(len(chunk))
@staticmethod @staticmethod
def _get_output_file_from_response(final_response): def _get_output_file_from_response(
final_response: requests.Response
) -> IO:
# Output file not specified. Pick a name that doesn't exist yet. # Output file not specified. Pick a name that doesn't exist yet.
filename = None filename = None
if 'Content-Disposition' in final_response.headers: if 'Content-Disposition' in final_response.headers:
@ -335,7 +342,7 @@ class Downloader(object):
return open(unique_filename, mode='a+b') return open(unique_filename, mode='a+b')
class Status(object): class DownloadStatus:
"""Holds details about the downland status.""" """Holds details about the downland status."""
def __init__(self): def __init__(self):
@ -371,14 +378,16 @@ class ProgressReporterThread(threading.Thread):
Uses threading to periodically update the status (speed, ETA, etc.). Uses threading to periodically update the status (speed, ETA, etc.).
"""
def __init__(self, status, output, tick=.1, update_interval=1):
""" """
:type status: Status def __init__(
:type output: file self,
""" status: DownloadStatus,
super(ProgressReporterThread, self).__init__() output: IO,
tick=.1,
update_interval=1
):
super().__init__()
self.status = status self.status = status
self.output = output self.output = output
self._tick = tick self._tick = tick

View File

@ -9,14 +9,15 @@ import errno
import mimetypes import mimetypes
import getpass import getpass
from io import BytesIO from io import BytesIO
from collections import namedtuple, Iterable, OrderedDict from collections import namedtuple, OrderedDict
# noinspection PyCompatibility # noinspection PyCompatibility
from argparse import ArgumentParser, ArgumentTypeError, ArgumentError import argparse
# TODO: Use MultiDict for headers once added to `requests`. # TODO: Use MultiDict for headers once added to `requests`.
# https://github.com/jakubroztocil/httpie/issues/130 # https://github.com/jakubroztocil/httpie/issues/130
from urllib.parse import urlsplit from urllib.parse import urlsplit
from httpie.context import Environment
from httpie.plugins import plugin_manager from httpie.plugins import plugin_manager
from requests.structures import CaseInsensitiveDict from requests.structures import CaseInsensitiveDict
@ -121,7 +122,7 @@ SSL_VERSION_ARG_MAPPING = {
} }
class HTTPieArgumentParser(ArgumentParser): class HTTPieArgumentParser(argparse.ArgumentParser):
"""Adds additional logic to `argparse.ArgumentParser`. """Adds additional logic to `argparse.ArgumentParser`.
Handles all input (CLI args, file args, stdin), applies defaults, Handles all input (CLI args, file args, stdin), applies defaults,
@ -131,16 +132,21 @@ class HTTPieArgumentParser(ArgumentParser):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['add_help'] = False kwargs['add_help'] = False
super(HTTPieArgumentParser, self).__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.env = None self.env = None
self.args = None self.args = None
self.has_stdin_data = False self.has_stdin_data = False
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
def parse_args(self, env, program_name='http', args=None, namespace=None): def parse_args(
self,
env: Environment,
program_name='http',
args=None,
namespace=None
) -> argparse.Namespace:
self.env = env self.env = env
self.args, no_options = super( self.args, no_options = super().parse_known_args(args, namespace)
HTTPieArgumentParser, self).parse_known_args(args, namespace)
if self.args.debug: if self.args.debug:
self.args.traceback = True self.args.traceback = True
@ -193,7 +199,7 @@ class HTTPieArgumentParser(ArgumentParser):
}.get(file, file) }.get(file, file)
if not hasattr(file, 'buffer') and isinstance(message, str): if not hasattr(file, 'buffer') and isinstance(message, str):
message = message.encode(self.env.stdout_encoding) message = message.encode(self.env.stdout_encoding)
super(HTTPieArgumentParser, self)._print_message(message, file) super()._print_message(message, file)
def _setup_standard_streams(self): def _setup_standard_streams(self):
""" """
@ -342,7 +348,7 @@ class HTTPieArgumentParser(ArgumentParser):
self.args.items.insert(0, KeyValueArgType( self.args.items.insert(0, KeyValueArgType(
*SEP_GROUP_ALL_ITEMS).__call__(self.args.url)) *SEP_GROUP_ALL_ITEMS).__call__(self.args.url))
except ArgumentTypeError as e: except argparse.ArgumentTypeError as e:
if self.args.traceback: if self.args.traceback:
raise raise
self.error(e.args[0]) self.error(e.args[0])
@ -461,7 +467,7 @@ class ParseError(Exception):
pass pass
class KeyValue(object): class KeyValue:
"""Base key-value pair parsed from CLI.""" """Base key-value pair parsed from CLI."""
def __init__(self, key, value, sep, orig): def __init__(self, key, value, sep, orig):
@ -477,7 +483,7 @@ class KeyValue(object):
return repr(self.__dict__) return repr(self.__dict__)
class SessionNameValidator(object): class SessionNameValidator:
def __init__(self, error_message): def __init__(self, error_message):
self.error_message = error_message self.error_message = error_message
@ -486,11 +492,11 @@ class SessionNameValidator(object):
# Session name can be a path or just a name. # Session name can be a path or just a name.
if (os.path.sep not in value if (os.path.sep not in value
and not VALID_SESSION_NAME_PATTERN.search(value)): and not VALID_SESSION_NAME_PATTERN.search(value)):
raise ArgumentError(None, self.error_message) raise argparse.ArgumentError(None, self.error_message)
return value return value
class KeyValueArgType(object): class KeyValueArgType:
"""A key-value pair argument type used with `argparse`. """A key-value pair argument type used with `argparse`.
Parses a key-value arg and constructs a `KeyValue` instance. Parses a key-value arg and constructs a `KeyValue` instance.
@ -573,7 +579,7 @@ class KeyValueArgType(object):
break break
else: else:
raise ArgumentTypeError( raise argparse.ArgumentTypeError(
u'"%s" is not a valid value' % string) u'"%s" is not a valid value' % string)
return self.key_value_class( return self.key_value_class(
@ -611,8 +617,8 @@ class AuthCredentialsArgType(KeyValueArgType):
""" """
try: try:
return super(AuthCredentialsArgType, self).__call__(string) return super().__call__(string)
except ArgumentTypeError: except argparse.ArgumentTypeError:
# No password provided, will prompt for it later. # No password provided, will prompt for it later.
return self.key_value_class( return self.key_value_class(
key=string, key=string,
@ -639,10 +645,10 @@ class RequestItemsDict(OrderedDict):
""" """
assert not isinstance(value, list) assert not isinstance(value, list)
if key not in self: if key not in self:
super(RequestItemsDict, self).__setitem__(key, value) super().__setitem__(key, value)
else: else:
if not isinstance(self[key], list): if not isinstance(self[key], list):
super(RequestItemsDict, self).__setitem__(key, [self[key]]) super().__setitem__(key, [self[key]])
self[key].append(value) self[key].append(value)
@ -653,7 +659,7 @@ class ParamsDict(RequestItemsDict):
class DataDict(RequestItemsDict): class DataDict(RequestItemsDict):
def items(self): def items(self):
for key, values in super(RequestItemsDict, self).items(): for key, values in super().items():
if not isinstance(values, list): if not isinstance(values, list):
values = [values] values = [values]
for value in values: for value in values:
@ -757,4 +763,4 @@ def readable_file_arg(filename):
with open(filename, 'rb'): with open(filename, 'rb'):
return filename return filename
except IOError as ex: except IOError as ex:
raise ArgumentTypeError('%s: %s' % (filename, ex.args[1])) raise argparse.ArgumentTypeError('%s: %s' % (filename, ex.args[1]))

View File

@ -1,37 +1,38 @@
from typing import Iterable, Optional
from urllib.parse import urlsplit from urllib.parse import urlsplit
class HTTPMessage(object): class HTTPMessage:
"""Abstract class for HTTP messages.""" """Abstract class for HTTP messages."""
def __init__(self, orig): def __init__(self, orig):
self._orig = orig self._orig = orig
def iter_body(self, chunk_size): def iter_body(self, chunk_size: int) -> Iterable[bytes]:
"""Return an iterator over the body.""" """Return an iterator over the body."""
raise NotImplementedError() raise NotImplementedError()
def iter_lines(self, chunk_size): def iter_lines(self, chunk_size: int) -> Iterable[bytes]:
"""Return an iterator over the body yielding (`line`, `line_feed`).""" """Return an iterator over the body yielding (`line`, `line_feed`)."""
raise NotImplementedError() raise NotImplementedError()
@property @property
def headers(self): def headers(self) -> str:
"""Return a `str` with the message's headers.""" """Return a `str` with the message's headers."""
raise NotImplementedError() raise NotImplementedError()
@property @property
def encoding(self): def encoding(self) -> Optional[str]:
"""Return a `str` with the message's encoding, if known.""" """Return a `str` with the message's encoding, if known."""
raise NotImplementedError() raise NotImplementedError()
@property @property
def body(self): def body(self) -> bytes:
"""Return a `bytes` with the message's body.""" """Return a `bytes` with the message's body."""
raise NotImplementedError() raise NotImplementedError()
@property @property
def content_type(self): def content_type(self) -> str:
"""Return the message content type.""" """Return the message content type."""
ct = self._orig.headers.get('Content-Type', '') ct = self._orig.headers.get('Content-Type', '')
if not isinstance(ct, str): if not isinstance(ct, str):

View File

@ -42,7 +42,7 @@ class ColorFormatter(FormatterPlugin):
def __init__(self, env, explicit_json=False, def __init__(self, env, explicit_json=False,
color_scheme=DEFAULT_STYLE, **kwargs): color_scheme=DEFAULT_STYLE, **kwargs):
super(ColorFormatter, self).__init__(**kwargs) super().__init__(**kwargs)
if not env.colors: if not env.colors:
self.enabled = False self.enabled = False

View File

@ -11,7 +11,7 @@ def is_valid_mime(mime):
return mime and MIME_RE.match(mime) return mime and MIME_RE.match(mime)
class Conversion(object): class Conversion:
def get_converter(self, mime): def get_converter(self, mime):
if is_valid_mime(mime): if is_valid_mime(mime):
@ -20,7 +20,7 @@ class Conversion(object):
return converter_class(mime) return converter_class(mime)
class Formatting(object): class Formatting:
"""A delegate class that invokes the actual processors.""" """A delegate class that invokes the actual processors."""
def __init__(self, groups, env=Environment(), **kwargs): def __init__(self, groups, env=Environment(), **kwargs):

View File

@ -124,7 +124,7 @@ def get_stream_type(env, args):
return Stream return Stream
class BaseStream(object): class BaseStream:
"""Base HTTP message output stream class.""" """Base HTTP message output stream class."""
def __init__(self, msg, with_headers=True, with_body=True, def __init__(self, msg, with_headers=True, with_body=True,
@ -174,7 +174,7 @@ class RawStream(BaseStream):
CHUNK_SIZE_BY_LINE = 1 CHUNK_SIZE_BY_LINE = 1
def __init__(self, chunk_size=CHUNK_SIZE, **kwargs): def __init__(self, chunk_size=CHUNK_SIZE, **kwargs):
super(RawStream, self).__init__(**kwargs) super().__init__(**kwargs)
self.chunk_size = chunk_size self.chunk_size = chunk_size
def iter_body(self): def iter_body(self):
@ -193,7 +193,7 @@ class EncodedStream(BaseStream):
def __init__(self, env=Environment(), **kwargs): def __init__(self, env=Environment(), **kwargs):
super(EncodedStream, self).__init__(**kwargs) super().__init__(**kwargs)
if env.stdout_isatty: if env.stdout_isatty:
# Use the encoding supported by the terminal. # Use the encoding supported by the terminal.
@ -228,7 +228,7 @@ class PrettyStream(EncodedStream):
CHUNK_SIZE = 1 CHUNK_SIZE = 1
def __init__(self, conversion, formatting, **kwargs): def __init__(self, conversion, formatting, **kwargs):
super(PrettyStream, self).__init__(**kwargs) super().__init__(**kwargs)
self.formatting = formatting self.formatting = formatting
self.conversion = conversion self.conversion = conversion
self.mime = self.msg.content_type.split(';')[0] self.mime = self.msg.content_type.split(';')[0]

View File

@ -1,4 +1,4 @@
class BasePlugin(object): class BasePlugin:
# The name of the plugin, eg. "My auth". # The name of the plugin, eg. "My auth".
name = None name = None
@ -75,7 +75,7 @@ class TransportPlugin(BasePlugin):
raise NotImplementedError() raise NotImplementedError()
class ConverterPlugin(object): class ConverterPlugin:
def __init__(self, mime): def __init__(self, mime):
self.mime = mime self.mime = mime
@ -88,7 +88,7 @@ class ConverterPlugin(object):
raise NotImplementedError raise NotImplementedError
class FormatterPlugin(object): class FormatterPlugin:
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """

View File

@ -12,7 +12,7 @@ ENTRY_POINT_NAMES = [
] ]
class PluginManager(object): class PluginManager:
def __init__(self): def __init__(self):
self._plugins = [] self._plugins = []

View File

@ -3,16 +3,20 @@
""" """
import re import re
import os import os
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urlsplit from urllib.parse import urlsplit
from requests.auth import AuthBase
from requests.cookies import RequestsCookieJar, create_cookie from requests.cookies import RequestsCookieJar, create_cookie
import requests
from httpie.config import BaseConfigDict, DEFAULT_CONFIG_DIR from httpie.config import BaseConfigDict, DEFAULT_CONFIG_DIR
from httpie.plugins import plugin_manager from httpie.plugins import plugin_manager
SESSIONS_DIR_NAME = 'sessions' SESSIONS_DIR_NAME = 'sessions'
DEFAULT_SESSIONS_DIR = os.path.join(DEFAULT_CONFIG_DIR, SESSIONS_DIR_NAME) DEFAULT_SESSIONS_DIR = DEFAULT_CONFIG_DIR / SESSIONS_DIR_NAME
VALID_SESSION_NAME_PATTERN = re.compile('^[a-zA-Z0-9_.-]+$') VALID_SESSION_NAME_PATTERN = re.compile('^[a-zA-Z0-9_.-]+$')
# Request headers starting with these prefixes won't be stored in sessions. # Request headers starting with these prefixes won't be stored in sessions.
# They are specific to each request. # They are specific to each request.
@ -20,8 +24,13 @@ VALID_SESSION_NAME_PATTERN = re.compile('^[a-zA-Z0-9_.-]+$')
SESSION_IGNORED_HEADER_PREFIXES = ['Content-', 'If-'] SESSION_IGNORED_HEADER_PREFIXES = ['Content-', 'If-']
def get_response(requests_session, session_name, def get_response(
config_dir, args, read_only=False): requests_session: requests.Session,
session_name: str,
config_dir: Path,
args,
read_only=False,
) -> requests.Response:
"""Like `client.get_responses`, but applies permanent """Like `client.get_responses`, but applies permanent
aspects of the session to the request. aspects of the session to the request.
@ -38,10 +47,10 @@ def get_response(requests_session, session_name,
# host:port => host_port # host:port => host_port
hostname = hostname.replace(':', '_') hostname = hostname.replace(':', '_')
path = os.path.join(config_dir, path = (
SESSIONS_DIR_NAME, config_dir / SESSIONS_DIR_NAME / hostname /
hostname, (session_name + '.json')
session_name + '.json') )
session = Session(path) session = Session(path)
session.load() session.load()
@ -77,9 +86,9 @@ class Session(BaseConfigDict):
helpurl = 'https://httpie.org/doc#sessions' helpurl = 'https://httpie.org/doc#sessions'
about = 'HTTPie session file' about = 'HTTPie session file'
def __init__(self, path, *args, **kwargs): def __init__(self, path: Union[str, Path]):
super(Session, self).__init__(*args, **kwargs) super().__init__()
self._path = path self._path = Path(path)
self['headers'] = {} self['headers'] = {}
self['cookies'] = {} self['cookies'] = {}
self['auth'] = { self['auth'] = {
@ -88,10 +97,10 @@ class Session(BaseConfigDict):
'password': None 'password': None
} }
def _get_path(self): def _get_path(self) -> Path:
return self._path return self._path
def update_headers(self, request_headers): def update_headers(self, request_headers: dict):
""" """
Update the session headers with the request ones while ignoring Update the session headers with the request ones while ignoring
certain name prefixes. certain name prefixes.
@ -102,7 +111,7 @@ class Session(BaseConfigDict):
for name, value in request_headers.items(): for name, value in request_headers.items():
if value is None: if value is None:
continue # Ignore explicitely unset headers continue # Ignore explicitly unset headers
value = value.decode('utf8') value = value.decode('utf8')
if name == 'User-Agent' and value.startswith('HTTPie/'): if name == 'User-Agent' and value.startswith('HTTPie/'):
@ -115,11 +124,11 @@ class Session(BaseConfigDict):
self['headers'][name] = value self['headers'][name] = value
@property @property
def headers(self): def headers(self) -> dict:
return self['headers'] return self['headers']
@property @property
def cookies(self): def cookies(self) -> RequestsCookieJar:
jar = RequestsCookieJar() jar = RequestsCookieJar()
for name, cookie_dict in self['cookies'].items(): for name, cookie_dict in self['cookies'].items():
jar.set_cookie(create_cookie( jar.set_cookie(create_cookie(
@ -128,10 +137,7 @@ class Session(BaseConfigDict):
return jar return jar
@cookies.setter @cookies.setter
def cookies(self, jar): def cookies(self, jar: RequestsCookieJar):
"""
:type jar: CookieJar
"""
# https://docs.python.org/2/library/cookielib.html#cookie-objects # https://docs.python.org/2/library/cookielib.html#cookie-objects
stored_attrs = ['value', 'path', 'secure', 'expires'] stored_attrs = ['value', 'path', 'secure', 'expires']
self['cookies'] = {} self['cookies'] = {}
@ -142,7 +148,7 @@ class Session(BaseConfigDict):
} }
@property @property
def auth(self): def auth(self) -> Optional[AuthBase]:
auth = self.get('auth', None) auth = self.get('auth', None)
if not auth or not auth['type']: if not auth or not auth['type']:
return return
@ -171,6 +177,6 @@ class Session(BaseConfigDict):
return plugin.get_auth(**credentials) return plugin.get_auth(**credentials)
@auth.setter @auth.setter
def auth(self, auth): def auth(self, auth: dict):
assert {'type', 'raw_auth'} == auth.keys() assert {'type', 'raw_auth'} == auth.keys()
self['auth'] = auth self['auth'] = auth

View File

@ -305,7 +305,7 @@ class TestNoOptions:
def test_invalid_no_options(self, httpbin): def test_invalid_no_options(self, httpbin):
r = http('--no-war', 'GET', httpbin.url + '/get', r = http('--no-war', 'GET', httpbin.url + '/get',
error_exit_ok=True) error_exit_ok=True)
assert r.exit_status == 1 assert r.exit_status == ExitStatus.ERROR
assert 'unrecognized arguments: --no-war' in r.stderr assert 'unrecognized arguments: --no-war' in r.stderr
assert 'GET /get HTTP/1.1' not in r assert 'GET /get HTTP/1.1' not in r

View File

@ -28,5 +28,5 @@ def test_default_options_overwrite(httpbin):
def test_current_version(): def test_current_version():
version = Environment().config['__meta__']['httpie'] version = MockEnvironment().config['__meta__']['httpie']
assert version == __version__ assert version == __version__

View File

@ -9,7 +9,7 @@ from utils import TESTS_ROOT
def has_docutils(): def has_docutils():
try: try:
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences,PyPackageRequirements
import docutils import docutils
return True return True
except ImportError: except ImportError:
@ -17,6 +17,7 @@ def has_docutils():
def rst_filenames(): def rst_filenames():
# noinspection PyShadowingNames
for root, dirnames, filenames in os.walk(os.path.dirname(TESTS_ROOT)): for root, dirnames, filenames in os.walk(os.path.dirname(TESTS_ROOT)):
if '.tox' not in root: if '.tox' not in root:
for filename in fnmatch.filter(filenames, '*.rst'): for filename in fnmatch.filter(filenames, '*.rst'):

View File

@ -14,7 +14,7 @@ from httpie.downloads import (
from utils import http, MockEnvironment from utils import http, MockEnvironment
class Response(object): class Response:
# noinspection PyDefaultArgument # noinspection PyDefaultArgument
def __init__(self, url, headers={}, status_code=200): def __init__(self, url, headers={}, status_code=200):
self.url = url self.url = url

View File

@ -24,7 +24,7 @@ def test_version():
r = http('--version', error_exit_ok=True) r = http('--version', error_exit_ok=True)
assert r.exit_status == httpie.ExitStatus.SUCCESS assert r.exit_status == httpie.ExitStatus.SUCCESS
# FIXME: py3 has version in stdout, py2 in stderr # FIXME: py3 has version in stdout, py2 in stderr
assert httpie.__version__ == r.stderr.strip() + r.strip() assert httpie.__version__ == r.strip()
def test_GET(httpbin_both): def test_GET(httpbin_both):

View File

@ -11,7 +11,7 @@ from utils import MockEnvironment, mk_config_dir, http, HTTP_OK
from fixtures import UNICODE from fixtures import UNICODE
class SessionTestBase(object): class SessionTestBase:
def start_session(self, httpbin): def start_session(self, httpbin):
"""Create and reuse a unique config dir for each test.""" """Create and reuse a unique config dir for each test."""
@ -44,7 +44,7 @@ class TestSessionFlow(SessionTestBase):
authorization, and response cookies. authorization, and response cookies.
""" """
super(TestSessionFlow, self).start_session(httpbin) super().start_session(httpbin)
r1 = http('--follow', '--session=test', '--auth=username:password', r1 = http('--follow', '--session=test', '--auth=username:password',
'GET', httpbin.url + '/cookies/set?hello=world', 'GET', httpbin.url + '/cookies/set?hello=world',
'Hello:World', 'Hello:World',
@ -130,12 +130,12 @@ class TestSession(SessionTestBase):
def test_session_by_path(self, httpbin): def test_session_by_path(self, httpbin):
self.start_session(httpbin) self.start_session(httpbin)
session_path = os.path.join(self.config_dir, 'session-by-path.json') session_path = self.config_dir / 'session-by-path.json'
r1 = http('--session=' + session_path, 'GET', httpbin.url + '/get', r1 = http('--session', str(session_path), 'GET', httpbin.url + '/get',
'Foo:Bar', env=self.env()) 'Foo:Bar', env=self.env())
assert HTTP_OK in r1 assert HTTP_OK in r1
r2 = http('--session=' + session_path, 'GET', httpbin.url + '/get', r2 = http('--session', str(session_path), 'GET', httpbin.url + '/get',
env=self.env()) env=self.env())
assert HTTP_OK in r2 assert HTTP_OK in r2
assert r2.json['headers']['Foo'] == 'Bar' assert r2.json['headers']['Foo'] == 'Bar'

View File

@ -5,8 +5,11 @@ import sys
import time import time
import json import json
import tempfile import tempfile
from pathlib import Path
from typing import Optional
from httpie import ExitStatus, EXIT_STATUS_LABELS from httpie import ExitStatus
from httpie.config import Config
from httpie.context import Environment from httpie.context import Environment
from httpie.core import main from httpie.core import main
@ -22,9 +25,9 @@ HTTP_OK_COLOR = (
) )
def mk_config_dir(): def mk_config_dir() -> Path:
dirname = tempfile.mkdtemp(prefix='httpie_config_') dirname = tempfile.mkdtemp(prefix='httpie_config_')
return dirname return Path(dirname)
def add_auth(url, auth): def add_auth(url, auth):
@ -40,7 +43,6 @@ class MockEnvironment(Environment):
is_windows = False is_windows = False
def __init__(self, create_temp_config_dir=True, **kwargs): def __init__(self, create_temp_config_dir=True, **kwargs):
self.create_temp_config_dir = create_temp_config_dir
if 'stdout' not in kwargs: if 'stdout' not in kwargs:
kwargs['stdout'] = tempfile.TemporaryFile( kwargs['stdout'] = tempfile.TemporaryFile(
mode='w+b', mode='w+b',
@ -51,22 +53,24 @@ class MockEnvironment(Environment):
mode='w+t', mode='w+t',
prefix='httpie_stderr' prefix='httpie_stderr'
) )
super(MockEnvironment, self).__init__(**kwargs) super().__init__(**kwargs)
self._create_temp_config_dir = create_temp_config_dir
self._delete_config_dir = False self._delete_config_dir = False
self._temp_dir = Path(tempfile.gettempdir())
@property @property
def config(self): def config(self) -> Config:
if (self.create_temp_config_dir if (self._create_temp_config_dir
and not self.config_dir.startswith(tempfile.gettempdir())): and self._temp_dir not in self.config_dir.parents):
self.config_dir = mk_config_dir() self.config_dir = mk_config_dir()
self._delete_config_dir = True self._delete_config_dir = True
return super(MockEnvironment, self).config return super().config
def cleanup(self): def cleanup(self):
self.stdout.close() self.stdout.close()
self.stderr.close() self.stderr.close()
if self._delete_config_dir: if self._delete_config_dir:
assert self.config_dir.startswith(tempfile.gettempdir()) assert self._temp_dir in self.config_dir.parents
from shutil import rmtree from shutil import rmtree
rmtree(self.config_dir) rmtree(self.config_dir)
@ -77,7 +81,7 @@ class MockEnvironment(Environment):
pass pass
class BaseCLIResponse(object): class BaseCLIResponse:
""" """
Represents the result of simulated `$ http' invocation via `http()`. Represents the result of simulated `$ http' invocation via `http()`.
@ -88,9 +92,9 @@ class BaseCLIResponse(object):
- exit_status output: print(self.exit_status) - exit_status output: print(self.exit_status)
""" """
stderr = None stderr: str = None
json = None json: dict = None
exit_status = None exit_status: ExitStatus = None
class BytesCLIResponse(bytes, BaseCLIResponse): class BytesCLIResponse(bytes, BaseCLIResponse):
@ -107,7 +111,7 @@ class BytesCLIResponse(bytes, BaseCLIResponse):
class StrCLIResponse(str, BaseCLIResponse): class StrCLIResponse(str, BaseCLIResponse):
@property @property
def json(self): def json(self) -> Optional[dict]:
""" """
Return deserialized JSON body, if one included in the output Return deserialized JSON body, if one included in the output
and is parsable. and is parsable.
@ -132,6 +136,7 @@ class StrCLIResponse(str, BaseCLIResponse):
pass pass
else: else:
try: try:
# noinspection PyAttributeOutsideInit
self._json = json.loads(j) self._json = json.loads(j)
except ValueError: except ValueError:
pass pass
@ -174,7 +179,7 @@ def http(*args, program_name='http', **kwargs):
>>> type(r) == StrCLIResponse >>> type(r) == StrCLIResponse
True True
>>> r.exit_status >>> r.exit_status
0 <ExitStatus.SUCCESS: 0>
>>> r.stderr >>> r.stderr
'' ''
>>> 'HTTP/1.1 200 OK' in r >>> 'HTTP/1.1 200 OK' in r
@ -227,10 +232,7 @@ def http(*args, program_name='http', **kwargs):
dump_stderr() dump_stderr()
raise ExitStatusError( raise ExitStatusError(
'httpie.core.main() unexpectedly returned' 'httpie.core.main() unexpectedly returned'
' a non-zero exit status: {0} ({1})'.format( f' a non-zero exit status: {exit_status}'
exit_status,
EXIT_STATUS_LABELS[exit_status]
)
) )
stdout.seek(0) stdout.seek(0)
@ -239,10 +241,8 @@ def http(*args, program_name='http', **kwargs):
try: try:
output = output.decode('utf8') output = output.decode('utf8')
except UnicodeDecodeError: except UnicodeDecodeError:
# noinspection PyArgumentList
r = BytesCLIResponse(output) r = BytesCLIResponse(output)
else: else:
# noinspection PyArgumentList
r = StrCLIResponse(output) r = StrCLIResponse(output)
r.stderr = stderr.read() r.stderr = stderr.read()
r.exit_status = exit_status r.exit_status = exit_status