Typing & cleanup

This commit is contained in:
Jakub Roztocil 2019-09-17 09:21:49 +02:00
parent 37fa67cd3c
commit a42b275ae2

View File

@ -2,7 +2,7 @@ import argparse
import getpass import getpass
import os import os
import sys import sys
from typing import Union, List from typing import Union, List, Optional
from httpie.cli.constants import SEPARATOR_CREDENTIALS from httpie.cli.constants import SEPARATOR_CREDENTIALS
from httpie.sessions import VALID_SESSION_NAME_PATTERN from httpie.sessions import VALID_SESSION_NAME_PATTERN
@ -11,13 +11,13 @@ from httpie.sessions import VALID_SESSION_NAME_PATTERN
class KeyValueArg: class KeyValueArg:
"""Base key-value pair parsed from CLI.""" """Base key-value pair parsed from CLI."""
def __init__(self, key, value, sep, orig): def __init__(self, key: str, value: Optional[str], sep: str, orig: str):
self.key = key self.key = key
self.value = value self.value = value
self.sep = sep self.sep = sep
self.orig = orig self.orig = orig
def __eq__(self, other): def __eq__(self, other: 'KeyValueArg'):
return self.__dict__ == other.__dict__ return self.__dict__ == other.__dict__
def __repr__(self): def __repr__(self):
@ -26,10 +26,10 @@ class KeyValueArg:
class SessionNameValidator: class SessionNameValidator:
def __init__(self, error_message): def __init__(self, error_message: str):
self.error_message = error_message self.error_message = error_message
def __call__(self, value): def __call__(self, value: str) -> str:
# Session name can be a path or just a name. # Session name can be a path or just a name.
if (os.path.sep not in value if (os.path.sep not in value
and not VALID_SESSION_NAME_PATTERN.search(value)): and not VALID_SESSION_NAME_PATTERN.search(value)):
@ -54,14 +54,14 @@ class KeyValueArgType:
key_value_class = KeyValueArg key_value_class = KeyValueArg
def __init__(self, *separators): def __init__(self, *separators: str):
self.separators = separators self.separators = separators
self.special_characters = set('\\') self.special_characters = set('\\')
for separator in separators: for separator in separators:
self.special_characters.update(separator) self.special_characters.update(separator)
def __call__(self, string) -> KeyValueArg: def __call__(self, s: str) -> KeyValueArg:
"""Parse `string` and return `self.key_value_class()` instance. """Parse raw string arg and return `self.key_value_class` instance.
The best of `self.separators` is determined (first found, longest). The best of `self.separators` is determined (first found, longest).
Back slash escaped characters aren't considered as separators Back slash escaped characters aren't considered as separators
@ -69,7 +69,7 @@ class KeyValueArgType:
as well (r'\\'). as well (r'\\').
""" """
tokens = self.tokenize(string) tokens = self.tokenize(s)
# Sorting by length ensures that the longest one will be # Sorting by length ensures that the longest one will be
# chosen as it will overwrite any shorter ones starting # chosen as it will overwrite any shorter ones starting
@ -102,11 +102,9 @@ class KeyValueArgType:
break break
else: else:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(f'{s!r} is not a valid value')
u'"%s" is not a valid value' % string)
return self.key_value_class( return self.key_value_class(key=key, value=value, sep=sep, orig=s)
key=key, value=value, sep=sep, orig=string)
def tokenize(self, s: str) -> List[Union[str, Escaped]]: def tokenize(self, s: str) -> List[Union[str, Escaped]]:
r"""Tokenize the raw arg string r"""Tokenize the raw arg string
@ -134,42 +132,43 @@ class KeyValueArgType:
class AuthCredentials(KeyValueArg): class AuthCredentials(KeyValueArg):
"""Represents parsed credentials.""" """Represents parsed credentials."""
def _getpass(self, prompt): def has_password(self) -> bool:
# To allow mocking.
return getpass.getpass(str(prompt))
def has_password(self):
return self.value is not None return self.value is not None
def prompt_password(self, host): def prompt_password(self, host: str):
prompt_text = f'http: password for {self.key}@{host}: '
try: try:
self.value = self._getpass( self.value = self._getpass(prompt_text)
'http: password for %s@%s: ' % (self.key, host))
except (EOFError, KeyboardInterrupt): except (EOFError, KeyboardInterrupt):
sys.stderr.write('\n') sys.stderr.write('\n')
sys.exit(0) sys.exit(0)
@staticmethod
def _getpass(prompt):
# To allow easy mocking.
return getpass.getpass(str(prompt))
class AuthCredentialsArgType(KeyValueArgType): class AuthCredentialsArgType(KeyValueArgType):
"""A key-value arg type that parses credentials.""" """A key-value arg type that parses credentials."""
key_value_class = AuthCredentials key_value_class = AuthCredentials
def __call__(self, string): def __call__(self, s):
"""Parse credentials from `string`. """Parse credentials from `s`.
("username" or "username:password"). ("username" or "username:password").
""" """
try: try:
return super().__call__(string) return super().__call__(s)
except argparse.ArgumentTypeError: except argparse.ArgumentTypeError:
# No password provided, will prompt for it later. # No password provided, will prompt for it later.
return self.key_value_class( return self.key_value_class(
key=string, key=s,
value=None, value=None,
sep=SEPARATOR_CREDENTIALS, sep=SEPARATOR_CREDENTIALS,
orig=string orig=s
) )
@ -181,4 +180,4 @@ def readable_file_arg(filename):
with open(filename, 'rb'): with open(filename, 'rb'):
return filename return filename
except IOError as ex: except IOError as ex:
raise argparse.ArgumentTypeError('%s: %s' % (filename, ex.args[1])) raise argparse.ArgumentTypeError(f'{filename}: {ex.args[1]}')