TProxy UDP support, including DNS.

This commit is contained in:
Brian May 2011-08-22 12:03:28 +10:00
parent 20254bab57
commit 5e8ad544ee
5 changed files with 327 additions and 35 deletions

232
client.py
View File

@ -1,9 +1,26 @@
import struct, socket, select, errno, re, signal, time
import struct, select, errno, re, signal, time
import compat.ssubprocess as ssubprocess
import helpers, ssnet, ssh, ssyslog
from ssnet import SockWrapper, Handler, Proxy, Mux, MuxWrapper
from helpers import *
recvmsg = None
try:
# try getting recvmsg from python
import socket as pythonsocket
getattr(pythonsocket.socket,"recvmsg")
socket = pythonsocket
recvmsg = "python"
except AttributeError:
# try getting recvmsg from socket_ext library
try:
import socket_ext
getattr(socket_ext.socket,"recvmsg")
socket = socket_ext
recvmsg = "socket_ext"
except ImportError:
import socket
_extra_fd = os.open('/dev/null', os.O_RDONLY)
def got_signal(signum, frame):
@ -13,6 +30,80 @@ def got_signal(signum, frame):
_pidname = None
IP_TRANSPARENT = 19
IP_ORIGDSTADDR = 20
IP_RECVORIGDSTADDR = IP_ORIGDSTADDR
SOL_IPV6 = 41
IPV6_ORIGDSTADDR = 74
IPV6_RECVORIGDSTADDR = IPV6_ORIGDSTADDR
if recvmsg == "python":
def recv_udp(listener, bufsize):
debug3('Accept UDP python using recvmsg.\n')
data, ancdata, msg_flags, srcip = listener.recvmsg(4096,socket.CMSG_SPACE(24))
dstip = None
family = None
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if cmsg_level == socket.SOL_IP and cmsg_type == IP_ORIGDSTADDR:
family,port = struct.unpack('=HH', cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET:
start = 4
length = 4
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, cmsg_data[start:start+length])
dstip = (ip, port)
break
elif cmsg_level == SOL_IPV6 and cmsg_type == IPV6_ORIGDSTADDR:
family,port = struct.unpack('=HH', cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET6:
start = 8
length = 16
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, cmsg_data[start:start+length])
dstip = (ip, port)
break
return (srcip, dstip, data)
elif recvmsg == "socket_ext":
def recv_udp(listener, bufsize):
debug3('Accept UDP using socket_ext recvmsg.\n')
srcip, data, adata, flags = listener.recvmsg((bufsize,),socket.CMSG_SPACE(24))
dstip = None
family = None
for a in adata:
if a.cmsg_level == socket.SOL_IP and a.cmsg_type == IP_ORIGDSTADDR:
family,port = struct.unpack('=HH', a.cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET:
start = 4
length = 4
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, a.cmsg_data[start:start+length])
dstip = (ip, port)
break
elif a.cmsg_level == SOL_IPV6 and a.cmsg_type == IPV6_ORIGDSTADDR:
family,port = struct.unpack('=HH', a.cmsg_data[0:4])
port = socket.htons(port)
if family == socket.AF_INET6:
start = 8
length = 16
else:
raise Fatal("Unsupported socket type '%s'"%family)
ip = socket.inet_ntop(family, a.cmsg_data[start:start+length])
dstip = (ip, port)
break
return (srcip, dstip, data[0])
else:
def recv_udp(listener, bufsize):
debug3('Accept UDP using recvfrom.\n')
data, srcip = listener.recvfrom(bufsize)
return (srcip, None, data)
def check_daemon(pidfile):
global _pidname
_pidname = os.path.abspath(pidfile)
@ -140,7 +231,7 @@ class MultiListener:
class FirewallClient:
def __init__(self, port_v6, port_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method):
def __init__(self, port_v6, port_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method, udp):
self.auto_nets = []
self.subnets_include = subnets_include
self.subnets_exclude = subnets_exclude
@ -148,7 +239,7 @@ class FirewallClient:
['-v'] * (helpers.verbose or 0) +
['--firewall', str(port_v6), str(port_v4),
str(dnsport_v6), str(dnsport_v4),
method])
method, str(int(udp))])
if ssyslog._p:
argvbase += ['--syslog']
argv_tries = [
@ -221,13 +312,21 @@ class FirewallClient:
dnsreqs = {}
udp_by_src = {}
def expire_connections(now, mux):
for chan,(peer,sock,timeout) in dnsreqs.items():
for chan,timeout in dnsreqs.items():
if timeout < now:
debug3('expiring dnsreqs channel=%d peer=%r\n' % (chan, peer))
debug3('expiring dnsreqs channel=%d\n' % chan)
del mux.channels[chan]
del dnsreqs[chan]
debug3('Remaining DNS requests: %d\n' % len(dnsreqs))
for peer,(chan,timeout) in udp_by_src.items():
if timeout < now:
debug3('expiring UDP channel channel=%d peer=%r\n' % (chan, peer))
mux.send(chan, ssnet.CMD_UDP_CLOSE, '')
del mux.channels[chan]
del udp_by_src[peer]
debug3('Remaining UDP channels: %d\n' % len(udp_by_src))
def onaccept_tcp(listener, method, mux, handlers):
@ -268,29 +367,75 @@ def onaccept_tcp(listener, method, mux, handlers):
expire_connections(time.time(), mux)
def dns_done(chan, mux, data, method):
peer,sock,timeout = dnsreqs.get(chan) or (None,None,None)
debug3('dns_done: channel=%r peer=%r\n' % (chan, peer))
if peer:
del mux.channels[chan]
del dnsreqs[chan]
debug3('doing sendto %r\n' % (peer,))
sock.sendto(data, peer)
def udp_done(chan, data, method, family, dstip):
(src,srcport,data) = data.split(",",2)
srcip = (src,int(srcport))
debug3('doing send from %r to %r\n' % (srcip,dstip,))
try:
sender = socket.socket(family, socket.SOCK_DGRAM)
sender.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sender.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
sender.bind(srcip)
sender.sendto(data, dstip)
sender.close()
except socket.error, e:
debug1('-- ignored socket error sending UDP data: %r\n'%e)
def ondns(listener, method, mux, handlers):
pkt,peer = listener.recvfrom(4096)
def onaccept_udp(listener, method, mux, handlers):
now = time.time()
if pkt:
debug1('DNS request from %r: %d bytes\n' % (peer, len(pkt)))
srcip, dstip, data = recv_udp(listener, 4096)
if not dstip:
debug1("-- ignored UDP from %r: couldn't determine destination IP address\n" % (srcip,))
return
debug1('Accept UDP: %r -> %r.\n' % (srcip,dstip,))
if srcip in udp_by_src:
chan,timeout = udp_by_src[srcip]
else:
chan = mux.next_channel()
dnsreqs[chan] = peer,listener,now+30
mux.send(chan, ssnet.CMD_DNS_REQ, pkt)
mux.channels[chan] = lambda cmd,data: dns_done(chan, mux, data, method)
mux.channels[chan] = lambda cmd,data: udp_done(chan, data, method, listener.family, dstip=srcip)
mux.send(chan, ssnet.CMD_UDP_OPEN, listener.family)
udp_by_src[srcip] = chan,now+30
hdr = "%s,%r,"%(dstip[0], dstip[1])
mux.send(chan, ssnet.CMD_UDP_DATA, hdr+data)
expire_connections(now, mux)
def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
def dns_done(chan, data, method, sock, srcip, dstip, mux):
debug3('dns_done: channel=%d src=%r dst=%r\n' % (chan,srcip,dstip))
del mux.channels[chan]
del dnsreqs[chan]
if method == "tproxy":
debug3('doing send from %r to %r\n' % (srcip,dstip,))
sender = socket.socket(sock.family, socket.SOCK_DGRAM)
sender.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sender.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
sender.bind(srcip)
sender.sendto(data, dstip)
sender.close()
else:
debug3('doing sendto %r\n' % (dstip,))
sock.sendto(data, dstip)
def ondns(listener, method, mux, handlers):
now = time.time()
srcip, dstip, data = recv_udp(listener, 4096)
if method == "tproxy" and not dstip:
debug1("-- ignored UDP from %r: couldn't determine destination IP address\n" % (srcip,))
return
debug1('DNS request from %r to %r: %d bytes\n' % (srcip,dstip,len(data)))
chan = mux.next_channel()
dnsreqs[chan] = now+30
mux.send(chan, ssnet.CMD_DNS_REQ, data)
mux.channels[chan] = lambda cmd,data: dns_done(chan, data, method, listener, srcip=dstip, dstip=srcip, mux=mux)
expire_connections(now, mux)
def _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename, python, latency_control,
dns_listener, method, seed_hosts, auto_nets,
syslog, daemon):
handlers = []
@ -303,7 +448,7 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
try:
(serverproc, serversock) = ssh.connect(ssh_cmd, remotename, python,
stderr=ssyslog._p and ssyslog._p.stdin,
options=dict(latency_control=latency_control))
options=dict(latency_control=latency_control, method=method))
except socket.error, e:
if e.args[0] == errno.EPIPE:
raise Fatal("failed to establish ssh session (1)")
@ -372,6 +517,9 @@ def _main(tcp_listener, fw, ssh_cmd, remotename, python, latency_control,
tcp_listener.add_handler(handlers, onaccept_tcp, method, mux)
if udp_listener:
udp_listener.add_handler(handlers, onaccept_udp, method, mux)
if dns_listener:
dns_listener.add_handler(handlers, ondns, method, mux)
@ -404,6 +552,23 @@ def main(listenip_v6, listenip_v4,
return 5
debug1('Starting sshuttle proxy.\n')
if recvmsg is not None:
debug1("recvmsg %s support enabled.\n"%recvmsg)
if method == "tproxy":
if recvmsg is not None:
debug1("tproxy UDP support enabled.\n")
udp = True
else:
debug1("tproxy UDP support requires recvmsg function.\n")
udp = False
if dns and recvmsg is None:
debug1("tproxy DNS support requires recvmsg function.\n")
dns = False
else:
debug1("UDP support requires tproxy; disabling UDP.\n")
udp = False
if listenip_v6 and listenip_v6[1] and listenip_v4 and listenip_v4[1]:
# if both ports given, no need to search for a spare port
ports = [ 0, ]
@ -422,6 +587,12 @@ def main(listenip_v6, listenip_v4,
tcp_listener = MultiListener()
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if udp:
udp_listener = MultiListener(socket.SOCK_DGRAM)
udp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
else:
udp_listener = None
if listenip_v6 and listenip_v6[1]:
lv6 = listenip_v6
redirectport_v6 = lv6[1]
@ -444,6 +615,8 @@ def main(listenip_v6, listenip_v4,
try:
tcp_listener.bind(lv6, lv4)
if udp_listener:
udp_listener.bind(lv6, lv4)
bound = True
break
except socket.error, e:
@ -457,6 +630,8 @@ def main(listenip_v6, listenip_v4,
raise last_e
tcp_listener.listen(10)
tcp_listener.print_listening("TCP redirector")
if udp_listener:
udp_listener.print_listening("UDP redirector")
bound = False
if dns:
@ -466,7 +641,6 @@ def main(listenip_v6, listenip_v4,
for port in ports:
debug2(' %d' % port)
dns_listener = MultiListener(socket.SOCK_DGRAM)
dns_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if listenip_v6:
lv6 = (listenip_v6[0],port)
@ -501,13 +675,21 @@ def main(listenip_v6, listenip_v4,
dnsport_v4 = 0
dns_listener = None
fw = FirewallClient(redirectport_v6, redirectport_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method)
fw = FirewallClient(redirectport_v6, redirectport_v4, subnets_include, subnets_exclude, dnsport_v6, dnsport_v4, method, udp)
if fw.method == "tproxy":
tcp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
if udp_listener:
udp_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
udp_listener.v4.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
udp_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1)
if dns_listener:
dns_listener.setsockopt(socket.SOL_IP, IP_TRANSPARENT, 1)
dns_listener.v4.setsockopt(socket.SOL_IP, IP_RECVORIGDSTADDR, 1)
dns_listener.v6.setsockopt(SOL_IPV6, IPV6_RECVORIGDSTADDR, 1)
try:
return _main(tcp_listener, fw, ssh_cmd, remotename,
return _main(tcp_listener, udp_listener, fw, ssh_cmd, remotename,
python, latency_control, dns_listener,
fw.method, seed_hosts, auto_nets, syslog,
daemon)

View File

@ -69,10 +69,12 @@ def _ipt_ttl(family, *args):
# multiple copies shouldn't have overlapping subnets, or only the most-
# recently-started one will win (because we use "-I OUTPUT 1" instead of
# "-A OUTPUT").
def do_iptables_nat(port, dnsport, family, subnets):
def do_iptables_nat(port, dnsport, family, subnets, udp):
# only ipv4 supported with NAT
if family != socket.AF_INET:
raise Exception('Address family "%s" unsupported by nat method'%family_to_string(family))
if udp:
raise Exception("UDP not supported by nat method")
table = "nat"
def ipt(*args):
@ -122,7 +124,7 @@ def do_iptables_nat(port, dnsport, family, subnets):
'--to-ports', str(dnsport))
def do_iptables_tproxy(port, dnsport, family, subnets):
def do_iptables_tproxy(port, dnsport, family, subnets, udp):
if family not in [socket.AF_INET, socket.AF_INET6]:
raise Exception('Address family "%s" unsupported by tproxy method'%family_to_string(family))
@ -164,6 +166,21 @@ def do_iptables_tproxy(port, dnsport, family, subnets):
ipt('-A', divert_chain, '-j', 'ACCEPT')
ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'tcp', '-p', 'tcp')
if subnets and udp:
ipt('-A', tproxy_chain, '-m', 'socket', '-j', divert_chain,
'-m', 'udp', '-p', 'udp')
if dnsport:
nslist = resolvconf_nameservers()
for f,ip in filter(lambda i: i[0]==family, nslist):
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53')
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1',
'--dest', '%s/32' % ip,
'-m', 'udp', '-p', 'udp', '--dport', '53',
'--on-port', str(dnsport))
if subnets:
for f,swidth,sexclude,snet in sorted(subnets, key=lambda s: s[1], reverse=True):
if sexclude:
@ -182,6 +199,22 @@ def do_iptables_tproxy(port, dnsport, family, subnets):
'-m', 'tcp', '-p', 'tcp',
'--on-port', str(port))
if sexclude and udp:
ipt('-A', mark_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'udp', '-p', 'udp')
ipt('-A', tproxy_chain, '-j', 'RETURN',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'udp', '-p', 'udp')
elif udp:
ipt('-A', mark_chain, '-j', 'MARK', '--set-mark', '1',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'udp', '-p', 'udp')
ipt('-A', tproxy_chain, '-j', 'TPROXY', '--tproxy-mark', '0x1/0x1',
'--dest', '%s/%s' % (snet,swidth),
'-m', 'udp', '-p', 'udp',
'--on-port', str(port))
def ipfw_rule_exists(n):
argv = ['ipfw', 'list']
@ -288,10 +321,12 @@ def ipfw(*args):
raise Fatal('%r returned %d' % (argv, rv))
def do_ipfw(port, dnsport, family, subnets):
def do_ipfw(port, dnsport, family, subnets, udp):
# IPv6 not supported
if family not in [socket.AF_INET, ]:
raise Exception('Address family "%s" unsupported by ipfw method'%family_to_string(family))
if udp:
raise Exception("UDP not supported by ipfw method")
sport = str(port)
xsport = str(port+1)
@ -454,7 +489,7 @@ def restore_etc_hosts(port):
# exit. In case that fails, it's not the end of the world; future runs will
# supercede it in the transproxy list, at least, so the leftover rules
# are hopefully harmless.
def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, syslog):
def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, udp, syslog):
assert(port_v6 >= 0)
assert(port_v6 <= 65535)
assert(port_v4 >= 0)
@ -529,13 +564,13 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, syslog):
subnets_v6 = filter(lambda i: i[0]==socket.AF_INET6, subnets)
if port_v6:
do_wait = do_it(port_v6, dnsport_v6, socket.AF_INET6, subnets_v6)
do_wait = do_it(port_v6, dnsport_v6, socket.AF_INET6, subnets_v6, udp)
elif len(subnets_v6) > 0:
debug1("IPv6 subnets defined but IPv6 disabled\n")
subnets_v4 = filter(lambda i: i[0]==socket.AF_INET, subnets)
if port_v4:
do_wait = do_it(port_v4, dnsport_v4, socket.AF_INET, subnets_v4)
do_wait = do_it(port_v4, dnsport_v4, socket.AF_INET, subnets_v4, udp)
elif len(subnets_v4) > 0:
debug1('IPv4 subnets defined but IPv4 disabled\n')
@ -567,6 +602,8 @@ def main(port_v6, port_v4, dnsport_v6, dnsport_v4, method, syslog):
debug1('firewall manager: undoing changes.\n')
except:
pass
if port_v6:
do_it(port_v6, 0, socket.AF_INET6, [], udp)
if port_v4:
do_it(port_v4, 0, socket.AF_INET, [])
do_it(port_v4, 0, socket.AF_INET, [], udp)
restore_etc_hosts(port_v6 or port_v4)

View File

@ -123,11 +123,11 @@ try:
server.latency_control = opt.latency_control
sys.exit(server.main())
elif opt.firewall:
if len(extra) != 5:
o.fatal('exactly five arguments expected')
if len(extra) != 6:
o.fatal('exactly six arguments expected')
sys.exit(firewall.main(int(extra[0]), int(extra[1]),
int(extra[2]), int(extra[3]),
extra[4], opt.syslog))
extra[4], int(extra[5]), opt.syslog))
elif opt.hostwatch:
sys.exit(hostwatch.hw_main(extra))
else:

View File

@ -163,6 +163,35 @@ class DnsProxy(Handler):
self.ok = False
class UdpProxy(Handler):
def __init__(self, mux, chan, family):
sock = socket.socket(family, socket.SOCK_DGRAM)
Handler.__init__(self, [sock])
self.timeout = time.time()+30
self.mux = mux
self.chan = chan
self.sock = sock
if family == socket.AF_INET:
self.sock.setsockopt(socket.SOL_IP, socket.IP_TTL, 42)
def send(self, dstip, data):
debug2('UDP: sending to %r port %d\n' % dstip)
try:
self.sock.sendto(data,dstip)
except socket.error, e:
log('UDP send to %r port %d: %s\n' % (dstip[0], dstip[1], e))
return
def callback(self):
try:
data,peer = self.sock.recvfrom(4096)
except socket.error, e:
log('UDP recv from %r port %d: %s\n' % (peer[0], peer[1], e))
return
debug2('UDP response: %d bytes\n' % len(data))
hdr = "%s,%r,"%(peer[0], peer[1])
self.mux.send(self.chan, ssnet.CMD_UDP_DATA, hdr+data)
def main():
if helpers.verbose >= 1:
helpers.logprefix = ' s: '
@ -231,6 +260,34 @@ def main():
dnshandlers[channel] = h
mux.got_dns_req = dns_req
udphandlers = {}
def udp_req(channel, cmd, data):
debug2('Incoming UDP request channel=%d, cmd=%d\n' % (channel,cmd))
if cmd == ssnet.CMD_UDP_DATA:
(dstip,dstport,data) = data.split(",",2)
dstport = int(dstport)
debug2('is incoming UDP data. %r %d.\n' % (dstip,dstport))
h = udphandlers[channel]
h.send((dstip,dstport),data)
elif cmd == ssnet.CMD_UDP_CLOSE:
debug2('is incoming UDP close\n')
h = udphandlers[channel]
h.ok = False
del mux.channels[channel]
def udp_open(channel, data):
debug2('Incoming UDP open.\n')
family = int(data)
mux.channels[channel] = lambda cmd, data: udp_req(channel, cmd, data)
if channel in udphandlers:
raise Fatal('UDP connection channel %d already open'%channel)
else:
h = UdpProxy(mux, channel, family)
handlers.append(h)
udphandlers[channel] = h
mux.got_udp_open = udp_open
while mux.ok:
if hw.pid:
assert(hw.pid > 0)
@ -250,3 +307,8 @@ def main():
debug3('expiring dnsreqs channel=%d\n' % channel)
del dnshandlers[channel]
h.ok = False
for channel,h in udphandlers.items():
if not h.ok:
debug3('expiring UDP channel=%d\n' % channel)
del udphandlers[channel]
h.ok = False

View File

@ -25,6 +25,9 @@ CMD_HOST_REQ = 0x4208
CMD_HOST_LIST = 0x4209
CMD_DNS_REQ = 0x420a
CMD_DNS_RESPONSE = 0x420b
CMD_UDP_OPEN = 0x420c
CMD_UDP_DATA = 0x420d
CMD_UDP_CLOSE = 0x420e
cmd_to_name = {
CMD_EXIT: 'EXIT',
@ -39,6 +42,9 @@ cmd_to_name = {
CMD_HOST_LIST: 'HOST_LIST',
CMD_DNS_REQ: 'DNS_REQ',
CMD_DNS_RESPONSE: 'DNS_RESPONSE',
CMD_UDP_OPEN: 'UDP_OPEN',
CMD_UDP_DATA: 'UDP_DATA',
CMD_UDP_CLOSE: 'UDP_CLOSE',
}
@ -318,6 +324,7 @@ class Mux(Handler):
self.rsock = rsock
self.wsock = wsock
self.new_channel = self.got_dns_req = self.got_routes = None
self.got_udp_open = self.got_udp_data = self.got_udp_close = None
self.got_host_req = self.got_host_list = None
self.channels = {}
self.chani = 0
@ -383,6 +390,10 @@ class Mux(Handler):
assert(not self.channels.get(channel))
if self.got_dns_req:
self.got_dns_req(channel, data)
elif cmd == CMD_UDP_OPEN:
assert(not self.channels.get(channel))
if self.got_udp_open:
self.got_udp_open(channel, data)
elif cmd == CMD_ROUTES:
if self.got_routes:
self.got_routes(data)