Added tests for sessions.

This commit is contained in:
Jakub Roztocil 2012-09-17 02:15:00 +02:00
parent 6c2001d1f5
commit 548bef7dff
7 changed files with 153 additions and 19 deletions

View File

@ -9,6 +9,7 @@ from requests.compat import is_windows
from . import __doc__ from . import __doc__
from . import __version__ from . import __version__
from .config import DEFAULT_CONFIG_DIR
from .output import AVAILABLE_STYLES, DEFAULT_STYLE from .output import AVAILABLE_STYLES, DEFAULT_STYLE
from .input import (Parser, AuthCredentialsArgType, KeyValueArgType, from .input import (Parser, AuthCredentialsArgType, KeyValueArgType,
SEP_PROXY, SEP_CREDENTIALS, SEP_GROUP_ITEMS, SEP_PROXY, SEP_CREDENTIALS, SEP_GROUP_ITEMS,
@ -330,6 +331,7 @@ network.add_argument(
############################################################################### ###############################################################################
troubleshooting = parser.add_argument_group(title='Troubleshooting') troubleshooting = parser.add_argument_group(title='Troubleshooting')
troubleshooting.add_argument( troubleshooting.add_argument(
'--help', '--help',
action='help', default=SUPPRESS, action='help', default=SUPPRESS,

View File

@ -15,7 +15,7 @@ JSON = 'application/json; charset=utf-8'
DEFAULT_UA = 'HTTPie/%s' % __version__ DEFAULT_UA = 'HTTPie/%s' % __version__
def get_response(args): def get_response(args, config_dir):
"""Send the request and return a `request.Response`.""" """Send the request and return a `request.Response`."""
requests_kwargs = get_requests_kwargs(args) requests_kwargs = get_requests_kwargs(args)
@ -28,6 +28,7 @@ def get_response(args):
return requests.request(**requests_kwargs) return requests.request(**requests_kwargs)
else: else:
return sessions.get_response( return sessions.get_response(
config_dir=config_dir,
name=args.session or args.session_read_only, name=args.session or args.session_read_only,
request_kwargs=requests_kwargs, request_kwargs=requests_kwargs,
read_only=bool(args.session_read_only), read_only=bool(args.session_read_only),

View File

@ -17,10 +17,12 @@ class BaseConfigDict(dict):
name = None name = None
help = None help = None
directory=DEFAULT_CONFIG_DIR
def __init__(self, directory=DEFAULT_CONFIG_DIR, seq=None, **kwargs): def __init__(self, directory=None, *args, **kwargs):
super(BaseConfigDict, self).__init__(seq or [], **kwargs) super(BaseConfigDict, self).__init__(*args, **kwargs)
self.directory = directory if directory:
self.directory = directory
def __getattr__(self, item): def __getattr__(self, item):
return self[item] return self[item]
@ -73,8 +75,9 @@ class Config(BaseConfigDict):
DEFAULTS = { DEFAULTS = {
'default_content_type': 'json', 'default_content_type': 'json',
'default_options': []
} }
def __init__(self, seq=None, **kwargs): def __init__(self, *args, **kwargs):
super(Config, self).__init__(seq or [], **kwargs) super(Config, self).__init__(*args, **kwargs)
self.update(self.DEFAULTS) self.update(self.DEFAULTS)

View File

@ -58,6 +58,8 @@ def main(args=sys.argv[1:], env=Environment()):
Return exit status. Return exit status.
""" """
if env.config.default_options:
args = env.config.default_options + args
def error(msg, *args): def error(msg, *args):
msg = msg % args msg = msg % args
@ -74,7 +76,8 @@ def main(args=sys.argv[1:], env=Environment()):
try: try:
args = parser.parse_args(args=args, env=env) args = parser.parse_args(args=args, env=env)
response = get_response(args)
response = get_response(args, config_dir=env.config.directory)
if args.check_status: if args.check_status:
status = get_exist_status(response.status_code, status = get_exist_status(response.status_code,

View File

@ -45,8 +45,7 @@ class Environment(object):
@property @property
def config(self): def config(self):
if not hasattr(self, '_config'): if not hasattr(self, '_config'):
self._config = Config() self._config = Config(directory=self.config_dir)
self._config.directory = self.config_dir
if self._config.is_new: if self._config.is_new:
self._config.save() self._config.save()
else: else:

View File

@ -15,20 +15,24 @@ from requests.cookies import RequestsCookieJar, create_cookie
from requests.auth import HTTPBasicAuth, HTTPDigestAuth from requests.auth import HTTPBasicAuth, HTTPDigestAuth
from argparse import OPTIONAL from argparse import OPTIONAL
from .config import DEFAULT_CONFIG_DIR, BaseConfigDict from .config import BaseConfigDict, DEFAULT_CONFIG_DIR
from .output import PygmentsProcessor from .output import PygmentsProcessor
SESSIONS_DIR = os.path.join(DEFAULT_CONFIG_DIR, 'sessions') SESSIONS_DIR_NAME = 'sessions'
def get_response(name, request_kwargs, read_only=False): def get_response(name, request_kwargs, config_dir, read_only=False):
"""Like `client.get_response`, but applies permanent """Like `client.get_response`, but applies permanent
aspects of the session to the request. aspects of the session to the request.
""" """
host = Host(request_kwargs['headers'].get('Host', None) sessions_dir = os.path.join(config_dir, SESSIONS_DIR_NAME)
or urlparse(request_kwargs['url']).netloc.split('@')[-1]) host = Host(
root_dir=sessions_dir,
name=request_kwargs['headers'].get('Host', None)
or urlparse(request_kwargs['url']).netloc.split('@')[-1]
)
session = Session(host, name) session = Session(host, name)
session.load() session.load()
@ -60,8 +64,9 @@ def get_response(name, request_kwargs, read_only=False):
class Host(object): class Host(object):
"""A host is a per-host directory on the disk containing sessions files.""" """A host is a per-host directory on the disk containing sessions files."""
def __init__(self, name): def __init__(self, name, root_dir=DEFAULT_CONFIG_DIR):
self.name = name self.name = name
self.root_dir = root_dir
def __iter__(self): def __iter__(self):
"""Return a iterator yielding `(session_name, session_path)`.""" """Return a iterator yielding `(session_name, session_path)`."""
@ -76,7 +81,7 @@ class Host(object):
# Name will include ':' if a port is specified, which is invalid # Name will include ':' if a port is specified, which is invalid
# on windows. DNS does not allow '_' in a domain, or for it to end # on windows. DNS does not allow '_' in a domain, or for it to end
# in a number (I think?) # in a number (I think?)
path = os.path.join(SESSIONS_DIR, self.name.replace(':', '_')) path = os.path.join(self.root_dir, self.name.replace(':', '_'))
try: try:
os.makedirs(path, mode=0o700) os.makedirs(path, mode=0o700)
except OSError as e: except OSError as e:

View File

@ -26,8 +26,8 @@ import json
import argparse import argparse
import tempfile import tempfile
import unittest import unittest
import shutil
CRLF = '\r\n'
try: try:
from urllib.request import urlopen from urllib.request import urlopen
except ImportError: except ImportError:
@ -63,6 +63,7 @@ from httpie.output import BINARY_SUPPRESSED_NOTICE
from httpie.input import ParseError from httpie.input import ParseError
CRLF = '\r\n'
HTTPBIN_URL = os.environ.get('HTTPBIN_URL', HTTPBIN_URL = os.environ.get('HTTPBIN_URL',
'http://httpbin.org').rstrip('/') 'http://httpbin.org').rstrip('/')
@ -99,6 +100,10 @@ def httpbin(path):
return HTTPBIN_URL + path return HTTPBIN_URL + path
def mk_config_dir():
return tempfile.mkdtemp(prefix='httpie_test_config_dir_')
class TestEnvironment(Environment): class TestEnvironment(Environment):
colors = 0 colors = 0
stdin_isatty = True, stdin_isatty = True,
@ -113,8 +118,16 @@ class TestEnvironment(Environment):
if 'stderr' not in kwargs: if 'stderr' not in kwargs:
kwargs['stderr'] = tempfile.TemporaryFile('w+t') kwargs['stderr'] = tempfile.TemporaryFile('w+t')
self.delete_config_dir = False
if 'config_dir' not in kwargs:
kwargs['config_dir'] = mk_config_dir()
self.delete_config_dir = True
super(TestEnvironment, self).__init__(**kwargs) super(TestEnvironment, self).__init__(**kwargs)
def __del__(self):
if self.delete_config_dir:
shutil.rmtree(self.config_dir)
def has_docutils(): def has_docutils():
try: try:
@ -862,7 +875,6 @@ class ExitStatusTest(BaseTestCase):
self.assertEqual(r.exit_status, EXIT.OK) self.assertEqual(r.exit_status, EXIT.OK)
self.assertTrue(not r.stderr) self.assertTrue(not r.stderr)
@skip('httpbin.org always returns 500')
def test_timeout_exit_status(self): def test_timeout_exit_status(self):
r = http( r = http(
'--timeout=0.5', '--timeout=0.5',
@ -1232,7 +1244,6 @@ class ArgumentParserTestCase(unittest.TestCase):
]) ])
class READMETest(BaseTestCase): class READMETest(BaseTestCase):
@skipIf(not has_docutils(), 'docutils not installed') @skipIf(not has_docutils(), 'docutils not installed')
@ -1241,6 +1252,116 @@ class READMETest(BaseTestCase):
self.assertFalse(errors, msg=errors) self.assertFalse(errors, msg=errors)
class SessionTest(BaseTestCase):
@property
def env(self):
return TestEnvironment(config_dir=self.config_dir)
def setUp(self):
# Start a full-blown session with a custom request header,
# authorization, and response cookies.
self.config_dir = mk_config_dir()
r = http(
'--follow',
'--session=test',
'--auth=username:password',
'GET',
httpbin('/cookies/set?hello=world'),
'Hello:World',
env=self.env
)
self.assertIn(OK, r)
def tearDown(self):
shutil.rmtree(self.config_dir)
def test_session_create(self):
# Verify that the has been created
r = http(
'--session=test',
'GET',
httpbin('/get'),
env=self.env
)
self.assertIn(OK, r)
self.assertEqual(r.json['headers']['Hello'], 'World')
self.assertEqual(r.json['headers']['Cookie'], 'hello=world')
self.assertIn('Basic ', r.json['headers']['Authorization'])
def test_session_update(self):
# Get a response to a request from the original session.
r1 = http(
'--session=test',
'GET',
httpbin('/get'),
env=self.env
)
self.assertIn(OK, r1)
# Make a request modifying the session data.
r2 = http(
'--follow',
'--session=test',
'--auth=username:password2',
'GET',
httpbin('/cookies/set?hello=world2'),
'Hello:World2',
env=self.env
)
self.assertIn(OK, r2)
# Get a response to a request from the updated session.
r3 = http(
'--session=test',
'GET',
httpbin('/get'),
env=self.env
)
self.assertIn(OK, r3)
self.assertEqual(r3.json['headers']['Hello'], 'World2')
self.assertEqual(r3.json['headers']['Cookie'], 'hello=world2')
self.assertNotEqual(r1.json['headers']['Authorization'],
r3.json['headers']['Authorization'])
def test_session_only(self):
# Get a response from the original session.
r1 = http(
'--session=test',
'GET',
httpbin('/get'),
env=self.env
)
self.assertIn(OK, r1)
# Make a request modifying the session data but
# with --session-read-only.
r2 = http(
'--follow',
'--session-read-only=test',
'--auth=username:password2',
'GET',
httpbin('/cookies/set?hello=world2'),
'Hello:World2',
env=self.env
)
self.assertIn(OK, r2)
# Get a response from the updated session.
r3 = http(
'--session=test',
'GET',
httpbin('/get'),
env=self.env
)
self.assertIn(OK, r3)
# Should be the same as before r2.
self.assertDictEqual(r1.json, r3.json)
if __name__ == '__main__': if __name__ == '__main__':
#noinspection PyCallingNonCallable #noinspection PyCallingNonCallable
unittest.main() unittest.main()