From 81c89ce9be0c77f0b777ceffdce882bd8ff68f9a Mon Sep 17 00:00:00 2001 From: Avery Pennarun Date: Sun, 2 May 2010 02:23:42 -0400 Subject: [PATCH] Don't bother with a backtrace when we produce certain fatal errors. We'll introduce a new "Fatal" exception for this purpose, and throw it when we just want to print a user message and abort immediately. --- client.py | 20 ++++++++++++-------- helpers.py | 4 ++++ main.py | 54 +++++++++++++++++++++++++++++++----------------------- server.py | 7 +++++-- ssh.py | 5 ++--- 5 files changed, 54 insertions(+), 36 deletions(-) diff --git a/client.py b/client.py index 42da341..a48c2ef 100644 --- a/client.py +++ b/client.py @@ -20,27 +20,31 @@ def iptables_setup(port, subnets): ['--iptables', str(port)] + subnets_str) rv = subprocess.call(argv) if rv != 0: - raise Exception('%r returned %d' % (argv, rv)) + raise Fatal('%r returned %d' % (argv, rv)) def _main(listener, listenport, use_server, remotename, subnets): handlers = [] if use_server: - helpers.logprefix = 'c : ' + if helpers.verbose >= 1: + helpers.logprefix = 'c : ' + else: + helpers.logprefix = 'client: ' (serverproc, serversock) = ssh.connect(remotename) mux = Mux(serversock, serversock) handlers.append(mux) expected = 'SSHUTTLE0001' initstring = serversock.recv(len(expected)) - if initstring != expected: - raise Exception('expected server init string %r; got %r' - % (expected, initstring)) - + rv = serverproc.poll() if rv: - raise Exception('server died with error code %d' % rv) + raise Fatal('server died with error code %d' % rv) + if initstring != expected: + raise Fatal('expected server init string %r; got %r' + % (expected, initstring)) + # we definitely want to do this *after* starting ssh, or we might end # up intercepting the ssh connection! iptables_setup(listenport, subnets) @@ -67,7 +71,7 @@ def _main(listener, listenport, use_server, remotename, subnets): if use_server: rv = serverproc.poll() if rv: - raise Exception('server died with error code %d' % rv) + raise Fatal('server died with error code %d' % rv) r = set() w = set() diff --git a/helpers.py b/helpers.py index 5ffada5..23f3c84 100644 --- a/helpers.py +++ b/helpers.py @@ -15,3 +15,7 @@ def debug1(s): def debug2(s): if verbose >= 2: log(s) + + +class Fatal(Exception): + pass diff --git a/main.py b/main.py index b83502b..1e4756b 100755 --- a/main.py +++ b/main.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import sys, os, re import helpers, options, client, server, iptables +from helpers import * # list of: @@ -10,7 +11,7 @@ def parse_subnets(subnets_str): for s in subnets_str: m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s) if not m: - raise Exception('%r is not a valid IP subnet format' % s) + raise Fatal('%r is not a valid IP subnet format' % s) (a,b,c,d,width) = m.groups() (a,b,c,d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0)) if width == None: @@ -18,9 +19,9 @@ def parse_subnets(subnets_str): else: width = int(width) if a > 255 or b > 255 or c > 255 or d > 255: - raise Exception('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) + raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) if width > 32: - raise Exception('*/%d is greater than the maximum of 32' % width) + raise Fatal('*/%d is greater than the maximum of 32' % width) subnets.append(('%d.%d.%d.%d' % (a,b,c,d), width)) return subnets @@ -30,14 +31,14 @@ def parse_ipport(s): s = str(s) m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s) if not m: - raise Exception('%r is not a valid IP:port format' % s) + raise Fatal('%r is not a valid IP:port format' % s) (a,b,c,d,port) = m.groups() (a,b,c,d,port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0), int(port or 0)) if a > 255 or b > 255 or c > 255 or d > 255: - raise Exception('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) + raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) if port > 65535: - raise Exception('*:%d is greater than the maximum of 65535' % port) + raise Fatal('*:%d is greater than the maximum of 65535' % port) if a == None: a = b = c = d = 0 return ('%d.%d.%d.%d' % (a,b,c,d), port) @@ -60,20 +61,27 @@ o = options.Options('sshuttle', optspec) helpers.verbose = opt.verbose -if opt.server: - sys.exit(server.main()) -elif opt.iptables: - if len(extra) < 1: - o.fatal('at least one argument expected') - sys.exit(iptables.main(int(extra[0]), - parse_subnets(extra[1:]))) -else: - if len(extra) < 1: - o.fatal('at least one subnet expected') - remotename = opt.remote - if remotename == '' or remotename == '-': - remotename = None - sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), - not opt.noserver, - remotename, - parse_subnets(extra))) +try: + if opt.server: + sys.exit(server.main()) + elif opt.iptables: + if len(extra) < 1: + o.fatal('at least one argument expected') + sys.exit(iptables.main(int(extra[0]), + parse_subnets(extra[1:]))) + else: + if len(extra) < 1: + o.fatal('at least one subnet expected') + remotename = opt.remote + if remotename == '' or remotename == '-': + remotename = None + sys.exit(client.main(parse_ipport(opt.listen or '0.0.0.0:0'), + not opt.noserver, + remotename, + parse_subnets(extra))) +except Fatal, e: + log('fatal: %s\n' % e) + sys.exit(99) +except KeyboardInterrupt: + log('\nKeyboard interrupt: exiting.\n') + sys.exit(1) diff --git a/server.py b/server.py index b050c82..b0699e6 100644 --- a/server.py +++ b/server.py @@ -8,8 +8,11 @@ def main(): # synchronization header sys.stdout.write('SSHUTTLE0001') sys.stdout.flush() - - helpers.logprefix = ' s: ' + + if helpers.verbose >= 1: + helpers.logprefix = ' s: ' + else: + helpers.logprefix = 'server: ' handlers = [] mux = Mux(socket.fromfd(sys.stdin.fileno(), socket.AF_INET, socket.SOCK_STREAM), diff --git a/ssh.py b/ssh.py index 4c983b9..d778245 100644 --- a/ssh.py +++ b/ssh.py @@ -19,10 +19,9 @@ def connect(rhost): # stuff here. escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir) cmd = r""" - sh -c PATH=%s:'$PATH exec sshuttle --server' - """ % (escapedir,) + sh -c PATH=%s:'$PATH exec sshuttle --server%s' + """ % (escapedir, ' -v' * (helpers.verbose or 0)) argv = ['ssh', rhost, '--', cmd.strip()] - print repr(argv) (s1,s2) = socket.socketpair() def setup(): # runs in the child process