Don't bother with a backtrace when we produce certain fatal errors.

We'll introduce a new "Fatal" exception for this purpose, and throw it when
we just want to print a user message and abort immediately.
This commit is contained in:
Avery Pennarun 2010-05-02 02:23:42 -04:00
parent 2dd328ada4
commit 81c89ce9be
5 changed files with 54 additions and 36 deletions

View File

@ -20,26 +20,30 @@ def iptables_setup(port, subnets):
['--iptables', str(port)] + subnets_str) ['--iptables', str(port)] + subnets_str)
rv = subprocess.call(argv) rv = subprocess.call(argv)
if rv != 0: if rv != 0:
raise Exception('%r returned %d' % (argv, rv)) raise Fatal('%r returned %d' % (argv, rv))
def _main(listener, listenport, use_server, remotename, subnets): def _main(listener, listenport, use_server, remotename, subnets):
handlers = [] handlers = []
if use_server: if use_server:
if helpers.verbose >= 1:
helpers.logprefix = 'c : ' helpers.logprefix = 'c : '
else:
helpers.logprefix = 'client: '
(serverproc, serversock) = ssh.connect(remotename) (serverproc, serversock) = ssh.connect(remotename)
mux = Mux(serversock, serversock) mux = Mux(serversock, serversock)
handlers.append(mux) handlers.append(mux)
expected = 'SSHUTTLE0001' expected = 'SSHUTTLE0001'
initstring = serversock.recv(len(expected)) initstring = serversock.recv(len(expected))
if initstring != expected:
raise Exception('expected server init string %r; got %r'
% (expected, initstring))
rv = serverproc.poll() rv = serverproc.poll()
if rv: if rv:
raise Exception('server died with error code %d' % rv) raise Fatal('server died with error code %d' % rv)
if initstring != expected:
raise Fatal('expected server init string %r; got %r'
% (expected, initstring))
# we definitely want to do this *after* starting ssh, or we might end # we definitely want to do this *after* starting ssh, or we might end
# up intercepting the ssh connection! # up intercepting the ssh connection!
@ -67,7 +71,7 @@ def _main(listener, listenport, use_server, remotename, subnets):
if use_server: if use_server:
rv = serverproc.poll() rv = serverproc.poll()
if rv: if rv:
raise Exception('server died with error code %d' % rv) raise Fatal('server died with error code %d' % rv)
r = set() r = set()
w = set() w = set()

View File

@ -15,3 +15,7 @@ def debug1(s):
def debug2(s): def debug2(s):
if verbose >= 2: if verbose >= 2:
log(s) log(s)
class Fatal(Exception):
pass

20
main.py
View File

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys, os, re import sys, os, re
import helpers, options, client, server, iptables import helpers, options, client, server, iptables
from helpers import *
# list of: # list of:
@ -10,7 +11,7 @@ def parse_subnets(subnets_str):
for s in subnets_str: for s in subnets_str:
m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s) m = re.match(r'(\d+)(?:\.(\d+)\.(\d+)\.(\d+))?(?:/(\d+))?$', s)
if not m: if not m:
raise Exception('%r is not a valid IP subnet format' % s) raise Fatal('%r is not a valid IP subnet format' % s)
(a,b,c,d,width) = m.groups() (a,b,c,d,width) = m.groups()
(a,b,c,d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0)) (a,b,c,d) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0))
if width == None: if width == None:
@ -18,9 +19,9 @@ def parse_subnets(subnets_str):
else: else:
width = int(width) width = int(width)
if a > 255 or b > 255 or c > 255 or d > 255: if a > 255 or b > 255 or c > 255 or d > 255:
raise Exception('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d))
if width > 32: if width > 32:
raise Exception('*/%d is greater than the maximum of 32' % width) raise Fatal('*/%d is greater than the maximum of 32' % width)
subnets.append(('%d.%d.%d.%d' % (a,b,c,d), width)) subnets.append(('%d.%d.%d.%d' % (a,b,c,d), width))
return subnets return subnets
@ -30,14 +31,14 @@ def parse_ipport(s):
s = str(s) s = str(s)
m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s) m = re.match(r'(?:(\d+)\.(\d+)\.(\d+)\.(\d+))?(?::)?(?:(\d+))?$', s)
if not m: if not m:
raise Exception('%r is not a valid IP:port format' % s) raise Fatal('%r is not a valid IP:port format' % s)
(a,b,c,d,port) = m.groups() (a,b,c,d,port) = m.groups()
(a,b,c,d,port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0), (a,b,c,d,port) = (int(a or 0), int(b or 0), int(c or 0), int(d or 0),
int(port or 0)) int(port or 0))
if a > 255 or b > 255 or c > 255 or d > 255: if a > 255 or b > 255 or c > 255 or d > 255:
raise Exception('%d.%d.%d.%d has numbers > 255' % (a,b,c,d)) raise Fatal('%d.%d.%d.%d has numbers > 255' % (a,b,c,d))
if port > 65535: if port > 65535:
raise Exception('*:%d is greater than the maximum of 65535' % port) raise Fatal('*:%d is greater than the maximum of 65535' % port)
if a == None: if a == None:
a = b = c = d = 0 a = b = c = d = 0
return ('%d.%d.%d.%d' % (a,b,c,d), port) return ('%d.%d.%d.%d' % (a,b,c,d), port)
@ -60,6 +61,7 @@ o = options.Options('sshuttle', optspec)
helpers.verbose = opt.verbose helpers.verbose = opt.verbose
try:
if opt.server: if opt.server:
sys.exit(server.main()) sys.exit(server.main())
elif opt.iptables: elif opt.iptables:
@ -77,3 +79,9 @@ else:
not opt.noserver, not opt.noserver,
remotename, remotename,
parse_subnets(extra))) parse_subnets(extra)))
except Fatal, e:
log('fatal: %s\n' % e)
sys.exit(99)
except KeyboardInterrupt:
log('\nKeyboard interrupt: exiting.\n')
sys.exit(1)

View File

@ -9,7 +9,10 @@ def main():
sys.stdout.write('SSHUTTLE0001') sys.stdout.write('SSHUTTLE0001')
sys.stdout.flush() sys.stdout.flush()
if helpers.verbose >= 1:
helpers.logprefix = ' s: ' helpers.logprefix = ' s: '
else:
helpers.logprefix = 'server: '
handlers = [] handlers = []
mux = Mux(socket.fromfd(sys.stdin.fileno(), mux = Mux(socket.fromfd(sys.stdin.fileno(),
socket.AF_INET, socket.SOCK_STREAM), socket.AF_INET, socket.SOCK_STREAM),

5
ssh.py
View File

@ -19,10 +19,9 @@ def connect(rhost):
# stuff here. # stuff here.
escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir) escapedir = re.sub(r'([^\w/])', r'\\\\\\\1', nicedir)
cmd = r""" cmd = r"""
sh -c PATH=%s:'$PATH exec sshuttle --server' sh -c PATH=%s:'$PATH exec sshuttle --server%s'
""" % (escapedir,) """ % (escapedir, ' -v' * (helpers.verbose or 0))
argv = ['ssh', rhost, '--', cmd.strip()] argv = ['ssh', rhost, '--', cmd.strip()]
print repr(argv)
(s1,s2) = socket.socketpair() (s1,s2) = socket.socketpair()
def setup(): def setup():
# runs in the child process # runs in the child process