more improvements windows support

This commit is contained in:
nom3ad 2022-09-07 12:26:21 +05:30 committed by Brian May
parent 2c74476124
commit bd2f960743
4 changed files with 75 additions and 60 deletions

View File

@ -323,8 +323,8 @@ class FirewallClient:
try: try:
line = self.pfile.readline() line = self.pfile.readline()
except ConnectionResetError: except IOError:
# happens in Windows, when subprocess exists # happens when firewall subprocess exists
line='' line=''
rv = self.p.poll() # Check if process is still running rv = self.p.poll() # Check if process is still running

View File

@ -12,7 +12,7 @@ import io
import sshuttle.ssyslog as ssyslog import sshuttle.ssyslog as ssyslog
import sshuttle.helpers as helpers import sshuttle.helpers as helpers
from sshuttle.helpers import is_admin_user, log, debug1, debug2, Fatal from sshuttle.helpers import is_admin_user, log, debug1, debug2, debug3, Fatal
from sshuttle.methods import get_auto_method, get_method from sshuttle.methods import get_auto_method, get_method
HOSTSFILE = '/etc/hosts' HOSTSFILE = '/etc/hosts'
@ -214,9 +214,11 @@ def main(method_name, syslog):
try: try:
line = stdin.readline(128) line = stdin.readline(128)
if not line: if not line:
return # parent died; nothing to do # parent probably exited
except ConnectionResetError: return
# On windows, this is thrown when parent process closes it's socket pair end except IOError:
# On windows, this ConnectionResetError is thrown when parent process closes it's socket pair end
debug3('read from stdin failed: %s' % (e,))
return return
subnets = [] subnets = []
@ -322,21 +324,26 @@ def main(method_name, syslog):
socket.AF_INET, subnets_v4, udp, socket.AF_INET, subnets_v4, udp,
user, group, tmark) user, group, tmark)
if sys.platform != 'win32':
flush_systemd_dns_cache() flush_systemd_dns_cache()
stdout.write('STARTED\n')
try: try:
stdout.write('STARTED\n')
stdout.flush() stdout.flush()
except IOError: except IOError as e:
# the parent process died for some reason; he's surely been loud debug3('write to stdout failed: %s' % (e,))
# enough, so no reason to report another error
return return
# Now we wait until EOF or any other kind of exception. We need # Now we wait until EOF or any other kind of exception. We need
# to stay running so that we don't need a *second* password # to stay running so that we don't need a *second* password
# authentication at shutdown time - that cleanup is important! # authentication at shutdown time - that cleanup is important!
while 1: while 1:
try:
line = stdin.readline(128) line = stdin.readline(128)
except IOError as e:
debug3('read from stdin failed: %s' % (e,))
return
if line.startswith('HOST '): if line.startswith('HOST '):
(name, ip) = line[5:].strip().split(',', 1) (name, ip) = line[5:].strip().split(',', 1)
hostmap[name] = ip hostmap[name] = ip
@ -385,6 +392,7 @@ def main(method_name, syslog):
except Exception: except Exception:
debug2('An error occurred, ignoring it.') debug2('An error occurred, ignoring it.')
if sys.platform != 'win32':
try: try:
flush_systemd_dns_cache() flush_systemd_dns_cache()
except Exception: except Exception:

View File

@ -26,7 +26,7 @@ ConnectionTuple = namedtuple(
) )
MAX_CONNECTIONS = 2 #_000 WINDIVERT_MAX_CONNECTIONS = 10_000
class IPProtocol(IntEnum): class IPProtocol(IntEnum):
TCP = socket.IPPROTO_TCP TCP = socket.IPPROTO_TCP
@ -72,7 +72,7 @@ class ConnTrack:
self.shm_list = shared_memory.ShareableList([bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name) self.shm_list = shared_memory.ShareableList([bytes(self.struct_full_tuple.size) for _ in range(max_connections)], name=name)
self.is_owner = True self.is_owner = True
self.next_slot = 0 self.next_slot = 0
self.used_slotes = set() self.used_slots = set()
self.rlock = threading.RLock() self.rlock = threading.RLock()
except FileExistsError: except FileExistsError:
self.is_owner = False self.is_owner = False
@ -85,18 +85,18 @@ class ConnTrack:
def add(self, proto, src_addr, src_port, dst_addr, dst_port, state): def add(self, proto, src_addr, src_port, dst_addr, dst_port, state):
if not self.is_owner: if not self.is_owner:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
if len(self.used_slotes) >= self.max_connections: if len(self.used_slots) >= self.max_connections:
raise RuntimeError(f"No slot avaialble in ConnTrack {len(self.used_slotes)}/{self.max_connections}") raise RuntimeError(f"No slot available in ConnTrack {len(self.used_slots)}/{self.max_connections}")
if self.get(proto, src_addr, src_port): if self.get(proto, src_addr, src_port):
return return
for _ in range(self.max_connections): for _ in range(self.max_connections):
if self.next_slot not in self.used_slotes: if self.next_slot not in self.used_slots:
break break
self.next_slot = (self.next_slot +1) % self.max_connections self.next_slot = (self.next_slot +1) % self.max_connections
else: else:
raise RuntimeError("No slot avaialble in ConnTrack") # should not be here raise RuntimeError("No slot available in ConnTrack") # should not be here
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
dst_addr = ipaddress.ip_address(dst_addr) dst_addr = ipaddress.ip_address(dst_addr)
@ -106,9 +106,9 @@ class ConnTrack:
entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state) entry = (proto, ip_version, src_addr.packed, src_port, dst_addr.packed, dst_port, state_epoch, state)
packed = self.struct_full_tuple.pack(*entry) packed = self.struct_full_tuple.pack(*entry)
self.shm_list[self.next_slot] = packed self.shm_list[self.next_slot] = packed
self.used_slotes.add(self.next_slot) self.used_slots.add(self.next_slot)
proto = IPProtocol(proto) proto = IPProtocol(proto)
debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slotes)}") debug3(f"ConnTrack: added connection ({proto.name} {src_addr}:{src_port}->{dst_addr}:{dst_port} @{state_epoch}:{state.name}) to slot={self.next_slot} | #ActiveConn={len(self.used_slots)}")
@synchronized_method('rlock') @synchronized_method('rlock')
def update(self, proto, src_addr, src_port, state): def update(self, proto, src_addr, src_port, state):
@ -116,14 +116,14 @@ class ConnTrack:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for i in self.used_slotes: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
state_epoch = int(time.time()) state_epoch = int(time.time())
self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state) self.shm_list[i] = self.shm_list[i][:-5] + self.struct_state_tuple.pack(state_epoch, state)
debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slotes)}") debug3(f"ConnTrack: updated connection ({proto.name} {src_addr}:{src_port} @{state_epoch}:{state.name}) from slot={i} | #ActiveConn={len(self.used_slots)}")
return self._unpack(self.shm_list[i]) return self._unpack(self.shm_list[i])
else: else:
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slotes)}") debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to update to {state.name} | #ActiveConn={len(self.used_slots)}")
@synchronized_method('rlock') @synchronized_method('rlock')
def remove(self, proto, src_addr, src_port): def remove(self, proto, src_addr, src_port):
@ -131,15 +131,15 @@ class ConnTrack:
raise RuntimeError("Only owner can mutate ConnTrack") raise RuntimeError("Only owner can mutate ConnTrack")
src_addr = ipaddress.ip_address(src_addr) src_addr = ipaddress.ip_address(src_addr)
packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port) packed = self.struct_src_tuple.pack(proto, src_addr.version, src_addr.packed, src_port)
for i in self.used_slotes: for i in self.used_slots:
if self.shm_list[i].startswith(packed): if self.shm_list[i].startswith(packed):
conn = self._unpack(self.shm_list[i]) conn = self._unpack(self.shm_list[i])
self.shm_list[i] = b'' self.shm_list[i] = b''
self.used_slotes.remove(i) self.used_slots.remove(i)
debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slotes)}") debug3(f"ConnTrack: removed connection ({proto.name} src={src_addr}:{src_port}) from slot={i} | #ActiveConn={len(self.used_slots)}")
return conn return conn
else: else:
debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slotes)}") debug3(f"ConnTrack: connection ({proto.name} src={src_addr}:{src_port}) is not found to remove | #ActiveConn={len(self.used_slots)}")
def get(self, proto, src_addr, src_port): def get(self, proto, src_addr, src_port):
@ -156,7 +156,7 @@ class ConnTrack:
return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state)) return ConnectionTuple(IPProtocol(proto), ip_version, src_addr, src_port, dst_addr, dst_port, state_epoch, ConnState(state))
def __repr__(self): def __repr__(self):
return f"<ConnTrack(n={len(self.used_slotes) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>" return f"<ConnTrack(n={len(self.used_slots) if self.is_owner else '?'}, cap={len(self.shm_list)}, owner={self.is_owner})>"
class Method(BaseMethod): class Method(BaseMethod):
@ -164,24 +164,24 @@ class Method(BaseMethod):
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, def setup_firewall(self, port, dnsport, nslist, family, subnets, udp,
user, tmark): user, tmark):
log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}") log( f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
self.conntrack = ConnTrack(f'sshttle-windiver-{os.getppid()}', MAX_CONNECTIONS) self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getppid()}', WINDIVERT_MAX_CONNECTIONS)
proxy_addr = "10.0.2.15" proxy_addr = "10.0.2.15"
subnet_addreses = [] subnet_addresses = []
for (_, mask, exclude, network_addr, fport, lport) in subnets: for (_, mask, exclude, network_addr, fport, lport) in subnets:
if exclude: if exclude:
continue continue
assert fport == 0, 'custom port range not supported' assert fport == 0, 'custom port range not supported'
assert lport == 0, 'custom port range not supported' assert lport == 0, 'custom port range not supported'
subnet_addreses.append("%s/%s" % (network_addr, mask)) subnet_addresses.append("%s/%s" % (network_addr, mask))
debug2("setup_firewall() subnet_addreses=%s proxy_addr=%s:%s" % (subnet_addreses,proxy_addr,port)) debug2("setup_firewall() subnet_addresses=%s proxy_addr=%s:%s" % (subnet_addresses,proxy_addr,port))
# check permission # check permission
with pydivert.WinDivert('false'): with pydivert.WinDivert('false'):
pass pass
threading.Thread(name='outbound_divert', target=self._outbound_divert, args=(subnet_addreses, proxy_addr, port), daemon=True).start() threading.Thread(name='outbound_divert', target=self._outbound_divert, args=(subnet_addresses, proxy_addr, port), daemon=True).start()
threading.Thread(name='inbound_divert', target=self._inbound_divert, args=(proxy_addr, port), daemon=True).start() threading.Thread(name='inbound_divert', target=self._inbound_divert, args=(proxy_addr, port), daemon=True).start()
def restore_firewall(self, port, family, udp, user): def restore_firewall(self, port, family, udp, user):
@ -196,7 +196,7 @@ class Method(BaseMethod):
def get_tcp_dstip(self, sock): def get_tcp_dstip(self, sock):
if not hasattr(self, 'conntrack'): if not hasattr(self, 'conntrack'):
self.conntrack = ConnTrack(f'sshttle-windiver-{os.getpid()}') self.conntrack = ConnTrack(f'sshuttle-windivert-{os.getpid()}')
src_addr , src_port = sock.getpeername() src_addr , src_port = sock.getpeername()
c = self.conntrack.get(IPProtocol.TCP , src_addr, src_port) c = self.conntrack.get(IPProtocol.TCP , src_addr, src_port)
@ -245,14 +245,14 @@ class Method(BaseMethod):
with pydivert.WinDivert(filter) as w: with pydivert.WinDivert(filter) as w:
for pkt in w: for pkt in w:
debug3("<<< " + repr_pkt(pkt)) debug3("<<< " + repr_pkt(pkt))
if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK Conenction established if pkt.tcp.syn and pkt.tcp.ack: # SYN+ACK connection established
conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV) conn = self.conntrack.update(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port, ConnState.TCP_SYN_ACK_RECV)
elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown elif pkt.tcp.rst or (pkt.tcp.fin and pkt.tcp.ack): # RST or FIN+ACK Connection teardown
conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.remove(IPProtocol.TCP, pkt.dst_addr, pkt.dst_port)
else: else:
conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port) conn = self.conntrack.get(socket.IPPROTO_TCP, pkt.dst_addr, pkt.dst_port)
if not conn: if not conn:
debug2("Unexpcted packet: " + repr_pkt(pkt)) debug2("Unexpected packet: " + repr_pkt(pkt))
continue continue
pkt.ipv4.src_addr = conn.dst_addr pkt.ipv4.src_addr = conn.dst_addr
pkt.tcp.src_port = conn.dst_port pkt.tcp.src_port = conn.dst_port

View File

@ -12,7 +12,7 @@ import ipaddress
from urllib.parse import urlparse from urllib.parse import urlparse
import sshuttle.helpers as helpers import sshuttle.helpers as helpers
from sshuttle.helpers import debug2, which, get_path, Fatal from sshuttle.helpers import debug2, debug3, which, get_path, Fatal
def get_module_source(name): def get_module_source(name):
@ -224,24 +224,31 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
pstdout = ssubprocess.PIPE pstdout = ssubprocess.PIPE
def get_serversock(): def get_serversock():
import threading import threading
def steam_stdout_to_sock():
while True: def stream_stdout_to_sock():
data = p.stdout.read(1) try:
if not data: fd = p.stdout.fileno()
debug2("EOF on ssh process stdout. Process probably exited") for data in iter(lambda:os.read(fd, 16384), b''):
break s1.sendall(data)
n = s1.sendall(data) debug3(f"<<<<< p.stdout.read() {len(data)} {data[:min(32,len(data))]}...")
print("<<<<< p.stdout.read()", len(data), '->', n, data[:min(32,len(data))]) finally:
debug2("Thread 'stream_stdout_to_sock' exiting")
s1.close()
p.terminate()
def stream_sock_to_stdin(): def stream_sock_to_stdin():
while True: try:
data = s1.recv(16384) for data in iter(lambda:s1.recv(16384), b''):
if not data: debug3(f">>>>> p.stdout.write() {len(data)} {data[:min(32,len(data))]}...")
print(">>>>>> EOF stream_sock_to_stdin") while data:
break
n = p.stdin.write(data) n = p.stdin.write(data)
print(">>>>>> s1.recv()", len(data) , "->" , n , data[:min(32,len(data))]) data = data[n:]
p.communicate finally:
threading.Thread(target=steam_stdout_to_sock, name='steam_stdout_to_sock', daemon=True).start() debug2("Thread 'stream_sock_to_stdin' exiting")
s1.close()
p.terminate()
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start() threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
# s2.setblocking(False) # s2.setblocking(False)
return s2 return s2