sshuttle/sshuttle/server.py
2015-11-16 11:32:17 +11:00

332 lines
10 KiB
Python

import re
import struct
import socket
import traceback
import time
import sys
import os
import sshuttle.ssnet as ssnet
import sshuttle.helpers as helpers
import sshuttle.hostwatch as hostwatch
import subprocess as ssubprocess
from sshuttle.ssnet import Handler, Proxy, Mux, MuxWrapper
from sshuttle.helpers import log, debug1, debug2, debug3, Fatal, \
resolvconf_random_nameserver
def _ipmatch(ipstr):
if ipstr == 'default':
ipstr = '0.0.0.0/0'
m = re.match(r'^(\d+(\.\d+(\.\d+(\.\d+)?)?)?)(?:/(\d+))?$', ipstr)
if m:
g = m.groups()
ips = g[0]
width = int(g[4] or 32)
if g[1] is None:
ips += '.0.0.0'
width = min(width, 8)
elif g[2] is None:
ips += '.0.0'
width = min(width, 16)
elif g[3] is None:
ips += '.0'
width = min(width, 24)
return (struct.unpack('!I', socket.inet_aton(ips))[0], width)
def _ipstr(ip, width):
if width >= 32:
return ip
else:
return "%s/%d" % (ip, width)
def _maskbits(netmask):
if not netmask:
return 32
for i in range(32):
if netmask[0] & _shl(1, i):
return 32 - i
return 0
def _shl(n, bits):
return n * int(2 ** bits)
def _list_routes():
argv = ['netstat', '-rn']
p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE)
routes = []
for line in p.stdout:
cols = re.split(r'\s+', line)
ipw = _ipmatch(cols[0])
if not ipw:
continue # some lines won't be parseable; never mind
maskw = _ipmatch(cols[2]) # linux only
mask = _maskbits(maskw) # returns 32 if maskw is null
width = min(ipw[1], mask)
ip = ipw[0] & _shl(_shl(1, width) - 1, 32 - width)
routes.append(
(socket.AF_INET, socket.inet_ntoa(struct.pack('!I', ip)), width))
rv = p.wait()
if rv != 0:
log('WARNING: %r returned %d\n' % (argv, rv))
log('WARNING: That prevents --auto-nets from working.\n')
return routes
def list_routes():
for (family, ip, width) in _list_routes():
if not ip.startswith('0.') and not ip.startswith('127.'):
yield (family, ip, width)
def _exc_dump():
exc_info = sys.exc_info()
return ''.join(traceback.format_exception(*exc_info))
def start_hostwatch(seed_hosts):
s1, s2 = socket.socketpair()
pid = os.fork()
if not pid:
# child
rv = 99
try:
try:
s2.close()
os.dup2(s1.fileno(), 1)
os.dup2(s1.fileno(), 0)
s1.close()
rv = hostwatch.hw_main(seed_hosts) or 0
except Exception:
log('%s\n' % _exc_dump())
rv = 98
finally:
os._exit(rv)
s1.close()
return pid, s2
class Hostwatch:
def __init__(self):
self.pid = 0
self.sock = None
class DnsProxy(Handler):
def __init__(self, mux, chan, request):
# FIXME! IPv4 specific
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
Handler.__init__(self, [sock])
self.timeout = time.time() + 30
self.mux = mux
self.chan = chan
self.tries = 0
self.peer = None
self.request = request
self.sock = sock
# FIXME! IPv4 specific
self.sock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
self.try_send()
def try_send(self):
if self.tries >= 3:
return
self.tries += 1
# FIXME! Support IPv6 nameservers
self.peer = resolvconf_random_nameserver()[1]
self.sock.connect((self.peer, 53))
debug2('DNS: sending to %r\n' % self.peer)
try:
self.sock.send(self.request)
except socket.error as e:
if e.args[0] in ssnet.NET_ERRS:
# might have been spurious; try again.
# Note: these errors sometimes are reported by recv(),
# and sometimes by send(). We have to catch both.
debug2('DNS send to %r: %s\n' % (self.peer, e))
self.try_send()
return
else:
log('DNS send to %r: %s\n' % (self.peer, e))
return
def callback(self):
try:
data = self.sock.recv(4096)
except socket.error as e:
if e.args[0] in ssnet.NET_ERRS:
# might have been spurious; try again.
# Note: these errors sometimes are reported by recv(),
# and sometimes by send(). We have to catch both.
debug2('DNS recv from %r: %s\n' % (self.peer, e))
self.try_send()
return
else:
log('DNS recv from %r: %s\n' % (self.peer, e))
return
debug2('DNS response: %d bytes\n' % len(data))
self.mux.send(self.chan, ssnet.CMD_DNS_RESPONSE, data)
self.ok = False
class UdpProxy(Handler):
def __init__(self, mux, chan, family):
sock = socket.socket(family, socket.SOCK_DGRAM)
Handler.__init__(self, [sock])
self.timeout = time.time() + 30
self.mux = mux
self.chan = chan
self.sock = sock
if family == socket.AF_INET:
self.sock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
def send(self, dstip, data):
debug2('UDP: sending to %r port %d\n' % dstip)
try:
self.sock.sendto(data, dstip)
except socket.error as e:
log('UDP send to %r port %d: %s\n' % (dstip[0], dstip[1], e))
return
def callback(self):
try:
data, peer = self.sock.recvfrom(4096)
except socket.error as e:
log('UDP recv from %r port %d: %s\n' % (peer[0], peer[1], e))
return
debug2('UDP response: %d bytes\n' % len(data))
hdr = "%s,%r," % (peer[0], peer[1])
self.mux.send(self.chan, ssnet.CMD_UDP_DATA, hdr + data)
def main(latency_control):
if helpers.verbose >= 1:
helpers.logprefix = ' s: '
else:
helpers.logprefix = 'server: '
debug1('latency control setting = %r\n' % latency_control)
routes = list(list_routes())
debug1('available routes:\n')
for r in routes:
debug1(' %d/%s/%d\n' % r)
# synchronization header
sys.stdout.write('\0\0SSHUTTLE0001')
sys.stdout.flush()
handlers = []
mux = Mux(socket.fromfd(sys.stdin.fileno(),
socket.AF_INET, socket.SOCK_STREAM),
socket.fromfd(sys.stdout.fileno(),
socket.AF_INET, socket.SOCK_STREAM))
handlers.append(mux)
routepkt = ''
for r in routes:
routepkt += '%d,%s,%d\n' % r
mux.send(0, ssnet.CMD_ROUTES, routepkt)
hw = Hostwatch()
hw.leftover = ''
def hostwatch_ready():
assert(hw.pid)
content = hw.sock.recv(4096)
if content:
lines = (hw.leftover + content).split('\n')
if lines[-1]:
# no terminating newline: entry isn't complete yet!
hw.leftover = lines.pop()
lines.append('')
else:
hw.leftover = ''
mux.send(0, ssnet.CMD_HOST_LIST, '\n'.join(lines))
else:
raise Fatal('hostwatch process died')
def got_host_req(data):
if not hw.pid:
(hw.pid, hw.sock) = start_hostwatch(data.strip().split())
handlers.append(Handler(socks=[hw.sock],
callback=hostwatch_ready))
mux.got_host_req = got_host_req
def new_channel(channel, data):
(family, dstip, dstport) = data.split(',', 2)
family = int(family)
dstport = int(dstport)
outwrap = ssnet.connect_dst(family, dstip, dstport)
handlers.append(Proxy(MuxWrapper(mux, channel), outwrap))
mux.new_channel = new_channel
dnshandlers = {}
def dns_req(channel, data):
debug2('Incoming DNS request channel=%d.\n' % channel)
h = DnsProxy(mux, channel, data)
handlers.append(h)
dnshandlers[channel] = h
mux.got_dns_req = dns_req
udphandlers = {}
def udp_req(channel, cmd, data):
debug2('Incoming UDP request channel=%d, cmd=%d\n' % (channel, cmd))
if cmd == ssnet.CMD_UDP_DATA:
(dstip, dstport, data) = data.split(",", 2)
dstport = int(dstport)
debug2('is incoming UDP data. %r %d.\n' % (dstip, dstport))
h = udphandlers[channel]
h.send((dstip, dstport), data)
elif cmd == ssnet.CMD_UDP_CLOSE:
debug2('is incoming UDP close\n')
h = udphandlers[channel]
h.ok = False
del mux.channels[channel]
def udp_open(channel, data):
debug2('Incoming UDP open.\n')
family = int(data)
mux.channels[channel] = lambda cmd, data: udp_req(channel, cmd, data)
if channel in udphandlers:
raise Fatal('UDP connection channel %d already open' % channel)
else:
h = UdpProxy(mux, channel, family)
handlers.append(h)
udphandlers[channel] = h
mux.got_udp_open = udp_open
while mux.ok:
if hw.pid:
assert(hw.pid > 0)
(rpid, rv) = os.waitpid(hw.pid, os.WNOHANG)
if rpid:
raise Fatal(
'hostwatch exited unexpectedly: code 0x%04x\n' % rv)
ssnet.runonce(handlers, mux)
if latency_control:
mux.check_fullness()
mux.callback()
if dnshandlers:
now = time.time()
for channel, h in list(dnshandlers.items()):
if h.timeout < now or not h.ok:
debug3('expiring dnsreqs channel=%d\n' % channel)
del dnshandlers[channel]
h.ok = False
if udphandlers:
for channel, h in list(udphandlers.items()):
if not h.ok:
debug3('expiring UDP channel=%d\n' % channel)
del udphandlers[channel]
h.ok = False