mirror of
https://github.com/httpie/cli.git
synced 2024-11-25 17:23:25 +01:00
276 lines
7.8 KiB
Python
276 lines
7.8 KiB
Python
import argparse
|
||
import getpass
|
||
import os
|
||
import sys
|
||
from copy import deepcopy
|
||
from typing import List, Optional, Union
|
||
|
||
from .constants import DEFAULT_FORMAT_OPTIONS, SEPARATOR_CREDENTIALS
|
||
from ..sessions import VALID_SESSION_NAME_PATTERN
|
||
|
||
|
||
class KeyValueArg:
    """Base key-value pair parsed from CLI.

    All four constructor arguments are stored verbatim as attributes.
    Instances compare equal when all of their attributes are equal. Like
    any class that defines ``__eq__`` without ``__hash__``, they are
    unhashable.
    """

    def __init__(self, key: str, value: Optional[str], sep: str, orig: str):
        self.key = key  # the key part of the argument
        self.value = value  # the value part; may be None
        self.sep = sep  # the separator that was matched
        self.orig = orig  # the raw, unparsed argument string

    def __eq__(self, other: object):
        # Attribute-wise equality. For unrelated types, defer to Python's
        # default comparison instead of failing on a missing ``__dict__``.
        if not isinstance(other, KeyValueArg):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return repr(self.__dict__)
|
||
|
||
|
||
class SessionNameValidator:
    """argparse ``type=`` callable that validates a session name or path."""

    def __init__(self, error_message: str):
        self.error_message = error_message

    def __call__(self, value: str) -> str:
        # A value containing a path separator is accepted as a file path.
        if os.path.sep in value:
            return value
        # A bare name must match the allowed session-name pattern.
        if VALID_SESSION_NAME_PATTERN.search(value):
            return value
        raise argparse.ArgumentError(None, self.error_message)
|
||
|
||
|
||
class Escaped(str):
    """A ``str`` marking a single backslash-escaped character."""

    def __repr__(self):
        # Convert to a plain str first so repr() does not recurse into here.
        return 'Escaped(%r)' % (str(self),)
|
||
|
||
|
||
class KeyValueArgType:
    """A key-value pair argument type used with `argparse`.

    Calling an instance turns one raw command-line string into an instance
    of `key_value_class` (`KeyValueArg` by default). Used for headers, form
    data, and other key-value pair types.

    """

    key_value_class = KeyValueArg

    def __init__(self, *separators: str):
        self.separators = separators
        # Every character appearing in any separator may be backslash-escaped.
        self.special_characters = {
            char for separator in separators for char in separator
        }

    def __call__(self, s: str) -> KeyValueArg:
        r"""Split `s` at its best separator into a `key_value_class` instance.

        The first unescaped token that contains any separator is split at
        the separator starting earliest in it. Among separators starting at
        the same position, the longest wins. Backslash-escaped characters are
        never considered as separators (or parts thereof). Literal backslash
        characters have to be escaped as well (r'\\').

        Raises `argparse.ArgumentTypeError` if no separator is found.

        """
        tokens = self.tokenize(s)

        # Shortest first: a longer separator found at the same position is
        # recorded later and therefore replaces the shorter one.
        candidates = sorted(self.separators, key=len)

        for index, token in enumerate(tokens):
            if isinstance(token, Escaped):
                continue

            hits = {}
            for candidate in candidates:
                position = token.find(candidate)
                if position != -1:
                    hits[position] = candidate
            if not hits:
                continue

            sep = hits[min(hits)]
            head, tail = token.split(sep, 1)
            # Earlier tokens belong to the key, later tokens to the value.
            key = ''.join(tokens[:index]) + head
            value = tail + ''.join(tokens[index + 1:])
            return self.key_value_class(key=key, value=value, sep=sep, orig=s)

        raise argparse.ArgumentTypeError(f'{s!r} is not a valid value')

    def tokenize(self, s: str) -> List[Union[str, Escaped]]:
        r"""Tokenize the raw arg string

        There are only two token types - strings and escaped characters:

        >>> KeyValueArgType('=').tokenize(r'foo\=bar\\baz')
        ['foo', Escaped('='), 'bar\\\\baz']

        """
        tokens = ['']
        stream = iter(s)
        for ch in stream:
            if ch != '\\':
                tokens[-1] += ch
                continue
            # A backslash applies to the next character (if any). Only
            # separator characters become Escaped tokens. Any other escape
            # sequence is kept verbatim, backslash included.
            escaped = next(stream, '')
            if escaped in self.special_characters:
                tokens += [Escaped(escaped), '']
            else:
                tokens[-1] += '\\' + escaped
        return tokens
|
||
|
||
|
||
class PromptMixin:
    """Mixin adding an interactive password prompt to credential classes."""

    def _prompt_password(self, prompt: str) -> str:
        """Read a secret using ``http: <prompt>: `` as the prompt text.

        On EOF or a keyboard interrupt, a newline is written to stderr and
        the process exits with status 0.
        """
        try:
            return self._getpass(f'http: {prompt}: ')
        except (EOFError, KeyboardInterrupt):
            sys.stderr.write('\n')
            sys.exit(0)

    @staticmethod
    def _getpass(prompt):
        # Kept as a separate hook so it can be replaced easily in tests.
        text = str(prompt)
        return getpass.getpass(text)
|
||
|
||
|
||
class SSLCredentials(PromptMixin):
    """Holds the passphrase for a certificate's private key."""

    def __init__(self, value: Optional[str]) -> None:
        self.value = value

    def prompt_password(self, key_file: str) -> None:
        """Interactively ask for the passphrase of `key_file` and store it."""
        prompt = f'passphrase for {key_file}'
        self.value = self._prompt_password(prompt)
|
||
|
||
|
||
class AuthCredentials(KeyValueArg, PromptMixin):
    """Represents parsed credentials (``username`` or ``username:password``).

    ``key`` holds the username and ``value`` the password. ``value`` is
    None when no password was given, in which case `prompt_password` can
    fill it in interactively.
    """

    def has_password(self) -> bool:
        """Return True if a password is set (an empty string counts)."""
        return self.value is not None

    def prompt_password(self, host: str) -> None:
        """Interactively ask for the password of ``self.key`` at `host`.

        ``_prompt_password`` already appends ``": "`` to the prompt text.
        Ending this text with a colon as well would display ``"::"``.
        """
        self.value = self._prompt_password(f'password for {self.key}@{host}')
|
||
|
||
|
||
class AuthCredentialsArgType(KeyValueArgType):
    """Key-value arg type producing `AuthCredentials` instances."""

    key_value_class = AuthCredentials

    def __call__(self, s):
        """Parse credentials from `s` ("username" or "username:password").

        If `s` contains no unescaped separator, the whole string becomes the
        username and the password is left as None to be prompted for later.
        """
        try:
            credentials = super().__call__(s)
        except argparse.ArgumentTypeError:
            credentials = self.key_value_class(
                key=s, value=None, sep=SEPARATOR_CREDENTIALS, orig=s)
        return credentials
|
||
|
||
|
||
# Shared argparse type callable for "username" / "username:password" values.
parse_auth = AuthCredentialsArgType(SEPARATOR_CREDENTIALS)
|
||
|
||
|
||
def readable_file_arg(filename):
    """Return `filename` unchanged if it can be opened for binary reading.

    Raises `argparse.ArgumentTypeError` with the OS error text otherwise.
    """
    try:
        with open(filename, 'rb'):
            pass
    except OSError as ex:
        raise argparse.ArgumentTypeError(f'{ex.filename}: {ex.strerror}')
    return filename
|
||
|
||
|
||
def parse_format_options(s: str, defaults: Optional[dict]) -> dict:
    """
    Parse `s` and update `defaults` with the parsed values.

    `s` is a comma-separated list of ``section.key:value`` options. The
    whole option is lowercased before parsing. ``true``/``false`` become
    booleans, strings of decimal digits become ints, and anything else stays
    a string.

    If `defaults` is given, each option must name an existing
    ``section.key``, and its parsed value must have exactly the same type as
    the default. If `defaults` is None, any section/key is accepted.
    `defaults` itself is never mutated; an updated deep copy is returned.

    >>> parse_format_options(
    ...     defaults={'json': {'indent': 4, 'sort_keys': True}},
    ...     s='json.indent:2,json.sort_keys:False',
    ... )
    {'json': {'indent': 2, 'sort_keys': False}}

    :raises argparse.ArgumentTypeError: for a malformed option, an unknown
        key, or a value whose type differs from the default's.
    """
    value_map = {
        'true': True,
        'false': False,
    }
    options = deepcopy(defaults or {})
    for option in s.split(','):
        try:
            path, value = option.lower().split(':')
            section, key = path.split('.')
        except ValueError:
            raise argparse.ArgumentTypeError(f'invalid option {option!r}')

        if value in value_map:
            parsed_value = value_map[value]
        elif value.isdecimal():
            # ``isdecimal`` guarantees ``int()`` succeeds. ``isnumeric`` is
            # also true for e.g. '½' or '²', and ``int()`` would then raise an
            # uncaught ValueError instead of a proper argparse error.
            parsed_value = int(value)
        else:
            parsed_value = value

        if defaults is None:
            options.setdefault(section, {})
        else:
            try:
                default_value = defaults[section][key]
            except KeyError:
                raise argparse.ArgumentTypeError(
                    f'invalid key {path!r}')

            # Exact type match: e.g. a bool is not accepted for an int.
            default_type, parsed_type = type(default_value), type(parsed_value)
            if parsed_type is not default_type:
                raise argparse.ArgumentTypeError(
                    'invalid value'
                    f' {value!r} in {option!r}'
                    f' (expected {default_type.__name__}'
                    f' got {parsed_type.__name__})'
                )

        options[section][key] = parsed_value

    return options
|
||
|
||
|
||
# DEFAULT_FORMAT_OPTIONS joined into one comma-separated option string and
# parsed without defaults, so every section/key in it is accepted.
PARSED_DEFAULT_FORMAT_OPTIONS = parse_format_options(
    ','.join(DEFAULT_FORMAT_OPTIONS),
    None,
)
|
||
|
||
|
||
def response_charset_type(encoding: str) -> str:
    """Return `encoding` if Python can encode text with it.

    Raises `argparse.ArgumentTypeError` for unknown or non-text codecs.
    """
    try:
        # Encoding an empty string is a cheap way to look up the codec.
        ''.encode(encoding)
    except LookupError:
        raise argparse.ArgumentTypeError(
            f'{encoding!r} is not a supported encoding')
    else:
        return encoding
|
||
|
||
|
||
def response_mime_type(mime_type: str) -> str:
    """Return `mime_type` if it contains exactly one '/' (type/subtype)."""
    slash_count = mime_type.count('/')
    if slash_count == 1:
        return mime_type
    raise argparse.ArgumentTypeError(
        f'{mime_type!r} doesn’t look like a mime type; use type/subtype')
|