diff --git a/httpie/client.py b/httpie/client.py index a48527a2..8d338576 100644 --- a/httpie/client.py +++ b/httpie/client.py @@ -4,7 +4,7 @@ import json import sys from contextlib import contextmanager from pathlib import Path -from typing import Callable, Iterable, Union +from typing import Callable, Iterable from urllib.parse import urlparse, urlunparse import requests @@ -14,6 +14,7 @@ from . import __version__ from .adapters import HTTPieHTTPAdapter from .cli.dicts import HTTPHeadersDict from .encoding import UTF8 +from .models import RequestsMessage from .plugins.registry import plugin_manager from .sessions import get_httpie_session from .ssl import AVAILABLE_SSL_VERSION_ARG_MAPPING, HTTPieHTTPSAdapter @@ -36,7 +37,7 @@ def collect_messages( args: argparse.Namespace, config_dir: Path, request_body_read_callback: Callable[[bytes], None] = None, -) -> Iterable[Union[requests.PreparedRequest, requests.Response]]: +) -> Iterable[RequestsMessage]: httpie_session = None httpie_session_headers = None if args.session or args.session_read_only: diff --git a/httpie/core.py b/httpie/core.py index 44b3f5c5..3c6c1021 100644 --- a/httpie/core.py +++ b/httpie/core.py @@ -13,6 +13,11 @@ from .cli.constants import OUT_REQ_BODY, OUT_REQ_HEAD, OUT_RESP_BODY, OUT_RESP_H from .client import collect_messages from .context import Environment from .downloads import Downloader +from .models import ( + RequestsMessage, + RequestsMessageKind, + infer_requests_message_kind +) from .output.writer import write_message, write_stream, MESSAGE_SEPARATOR_BYTES from .plugins.registry import plugin_manager from .status import ExitStatus, http_status_to_exit_status @@ -111,18 +116,18 @@ def main(args: List[Union[str, bytes]] = sys.argv, env=Environment()) -> ExitSta def get_output_options( args: argparse.Namespace, - message: Union[requests.PreparedRequest, requests.Response] + message: RequestsMessage ) -> Tuple[bool, bool]: return { - requests.PreparedRequest: ( + RequestsMessageKind.REQUEST: ( OUT_REQ_HEAD in args.output_options, OUT_REQ_BODY in args.output_options, ), - requests.Response: ( + RequestsMessageKind.RESPONSE: ( OUT_RESP_HEAD in args.output_options, OUT_RESP_BODY in args.output_options, ), - }[type(message)] + }[infer_requests_message_kind(message)] def program(args: argparse.Namespace, env: Environment) -> ExitStatus: diff --git a/httpie/models.py b/httpie/models.py index 64079d0c..af3e5a98 100644 --- a/httpie/models.py +++ b/httpie/models.py @@ -1,4 +1,7 @@ -from typing import Iterable +import requests + +from enum import Enum, auto +from typing import Iterable, Union from urllib.parse import urlsplit from .utils import split_cookies, parse_content_type_header @@ -118,3 +121,20 @@ class HTTPRequest(HTTPMessage): # Happens with JSON/form request data parsed from the command line. body = body.encode() return body or b'' + + +RequestsMessage = Union[requests.PreparedRequest, requests.Response] + + +class RequestsMessageKind(Enum): + REQUEST = auto() + RESPONSE = auto() + + +def infer_requests_message_kind(message: RequestsMessage) -> RequestsMessageKind: + if isinstance(message, requests.PreparedRequest): + return RequestsMessageKind.REQUEST + elif isinstance(message, requests.Response): + return RequestsMessageKind.RESPONSE + else: + raise TypeError(f"Unexpected message type: {type(message).__name__}") diff --git a/httpie/output/writer.py b/httpie/output/writer.py index 6f251f7c..4650264d 100644 --- a/httpie/output/writer.py +++ b/httpie/output/writer.py @@ -2,10 +2,15 @@ import argparse import errno from typing import IO, TextIO, Tuple, Type, Union -import requests - from ..context import Environment -from ..models import HTTPRequest, HTTPResponse, HTTPMessage +from ..models import ( + HTTPRequest, + HTTPResponse, + HTTPMessage, + RequestsMessage, + RequestsMessageKind, + infer_requests_message_kind +) from .processing import Conversion, Formatting from .streams import ( BaseStream, BufferedPrettyStream, EncodedStream, PrettyStream, RawStream, @@ -17,7 +22,7 @@ MESSAGE_SEPARATOR_BYTES = MESSAGE_SEPARATOR.encode() def write_message( - requests_message: Union[requests.PreparedRequest, requests.Response], + requests_message: RequestsMessage, env: Environment, args: argparse.Namespace, with_headers=False, @@ -93,14 +98,14 @@ def write_stream_with_colors_win( def build_output_stream_for_message( args: argparse.Namespace, env: Environment, - requests_message: Union[requests.PreparedRequest, requests.Response], + requests_message: RequestsMessage, with_headers: bool, with_body: bool, ): message_type = { - requests.PreparedRequest: HTTPRequest, - requests.Response: HTTPResponse, - }[type(requests_message)] + RequestsMessageKind.REQUEST: HTTPRequest, + RequestsMessageKind.RESPONSE: HTTPResponse, + }[infer_requests_message_kind(requests_message)] stream_class, stream_kwargs = get_stream_type_and_kwargs( env=env, args=args,