mirror of
https://github.com/sshuttle/sshuttle.git
synced 2025-05-11 03:24:33 +02:00
fix windows CRLF issue on stdin/stdout
This commit is contained in:
parent
900acc3ac7
commit
4a84ad3be6
@ -307,7 +307,7 @@ class FirewallClient:
|
|||||||
|
|
||||||
def get_pfile():
|
def get_pfile():
|
||||||
if can_use_stdio:
|
if can_use_stdio:
|
||||||
self.p.stdin.write(b'STDIO:\n')
|
self.p.stdin.write(b'COM_STDIO:\n')
|
||||||
self.p.stdin.flush()
|
self.p.stdin.flush()
|
||||||
|
|
||||||
class RWPair:
|
class RWPair:
|
||||||
@ -334,7 +334,7 @@ class FirewallClient:
|
|||||||
socket_share_data = s1.share(self.p.pid)
|
socket_share_data = s1.share(self.p.pid)
|
||||||
s1.close()
|
s1.close()
|
||||||
socket_share_data_b64 = base64.b64encode(socket_share_data)
|
socket_share_data_b64 = base64.b64encode(socket_share_data)
|
||||||
self.p.stdin.write(b'SOCKETSHARE:' + socket_share_data_b64 + b'\n')
|
self.p.stdin.write(b'COM_SOCKETSHARE:' + socket_share_data_b64 + b'\n')
|
||||||
self.p.stdin.flush()
|
self.p.stdin.flush()
|
||||||
return s2.makefile('rwb')
|
return s2.makefile('rwb')
|
||||||
try:
|
try:
|
||||||
|
@ -84,7 +84,7 @@ def firewall_exit(signum, frame):
|
|||||||
# the typical exit process as described above.
|
# the typical exit process as described above.
|
||||||
global sshuttle_pid
|
global sshuttle_pid
|
||||||
if sshuttle_pid:
|
if sshuttle_pid:
|
||||||
debug1("Relaying interupt signal to sshuttle process %d\n" % sshuttle_pid)
|
debug1("Relaying interupt signal to sshuttle process %d" % sshuttle_pid)
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
sig = signal.CTRL_C_EVENT
|
sig = signal.CTRL_C_EVENT
|
||||||
else:
|
else:
|
||||||
@ -115,7 +115,7 @@ def _setup_daemon_for_unix_like():
|
|||||||
# setsid() fails if sudo is configured with the use_pty option.
|
# setsid() fails if sudo is configured with the use_pty option.
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return sys.stdin, sys.stdout
|
return sys.stdin.buffer, sys.stdout.buffer
|
||||||
|
|
||||||
|
|
||||||
def _setup_daemon_for_windows():
|
def _setup_daemon_for_windows():
|
||||||
@ -125,9 +125,9 @@ def _setup_daemon_for_windows():
|
|||||||
signal.signal(signal.SIGTERM, firewall_exit)
|
signal.signal(signal.SIGTERM, firewall_exit)
|
||||||
signal.signal(signal.SIGINT, firewall_exit)
|
signal.signal(signal.SIGINT, firewall_exit)
|
||||||
|
|
||||||
socket_share_data_prefix = 'SOCKETSHARE:'
|
socket_share_data_prefix = b'COM_SOCKETSHARE:'
|
||||||
line = sys.stdin.readline().strip()
|
line = sys.stdin.buffer.readline().strip()
|
||||||
if line.startswith('SOCKETSHARE:'):
|
if line.startswith(socket_share_data_prefix):
|
||||||
debug3('Using shared socket for communicating with sshuttle client process')
|
debug3('Using shared socket for communicating with sshuttle client process')
|
||||||
socket_share_data_b64 = line[len(socket_share_data_prefix):]
|
socket_share_data_b64 = line[len(socket_share_data_prefix):]
|
||||||
socket_share_data = base64.b64decode(socket_share_data_b64)
|
socket_share_data = base64.b64decode(socket_share_data_b64)
|
||||||
@ -135,12 +135,12 @@ def _setup_daemon_for_windows():
|
|||||||
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
|
sys.stdin = io.TextIOWrapper(sock.makefile('rb', buffering=0))
|
||||||
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
|
sys.stdout = io.TextIOWrapper(sock.makefile('wb', buffering=0), write_through=True)
|
||||||
sock.close()
|
sock.close()
|
||||||
elif line.startswith("STDIO:"):
|
elif line.startswith(b"COM_STDIO:"):
|
||||||
debug3('Using inherited stdio for communicating with sshuttle client process')
|
debug3('Using inherited stdio for communicating with sshuttle client process')
|
||||||
else:
|
else:
|
||||||
raise Fatal("Unexpected stdin: " + line)
|
raise Fatal("Unexpected stdin: " + line)
|
||||||
|
|
||||||
return sys.stdin, sys.stdout
|
return sys.stdin.buffer, sys.stdout.buffer
|
||||||
|
|
||||||
|
|
||||||
# Isolate function that needs to be replaced for tests
|
# Isolate function that needs to be replaced for tests
|
||||||
@ -221,33 +221,43 @@ def main(method_name, syslog):
|
|||||||
"PATH." % method_name)
|
"PATH." % method_name)
|
||||||
|
|
||||||
debug1('ready method name %s.' % method.name)
|
debug1('ready method name %s.' % method.name)
|
||||||
stdout.write('READY %s\n' % method.name)
|
stdout.write(('READY %s\n' % method.name).encode('ASCII'))
|
||||||
stdout.flush()
|
stdout.flush()
|
||||||
|
|
||||||
# we wait until we get some input before creating the rules. That way,
|
|
||||||
# sshuttle can launch us as early as possible (and get sudo password
|
def _read_next_string_line():
|
||||||
# authentication as early in the startup process as possible).
|
|
||||||
try:
|
try:
|
||||||
line = stdin.readline(128)
|
line = stdin.readline(128)
|
||||||
if not line:
|
if not line:
|
||||||
return # parent probably exited
|
return # parent probably exited
|
||||||
except ConnectionResetError as e:
|
return line.decode('ASCII').strip()
|
||||||
|
except IOError as e:
|
||||||
|
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
|
||||||
|
debug3('read from stdin failed: %s' % (e,))
|
||||||
|
return
|
||||||
|
# we wait until we get some input before creating the rules. That way,
|
||||||
|
# sshuttle can launch us as early as possible (and get sudo password
|
||||||
|
# authentication as early in the startup process as possible).
|
||||||
|
try:
|
||||||
|
line = _read_next_string_line()
|
||||||
|
if not line:
|
||||||
|
return # parent probably exited
|
||||||
|
except IOError as e:
|
||||||
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
|
# On windows, ConnectionResetError is thrown when parent process closes it's socket pair end
|
||||||
debug3('read from stdin failed: %s' % (e,))
|
debug3('read from stdin failed: %s' % (e,))
|
||||||
return
|
return
|
||||||
|
|
||||||
subnets = []
|
subnets = []
|
||||||
if line != 'ROUTES\n':
|
if line != 'ROUTES':
|
||||||
raise Fatal('expected ROUTES but got %r' % line)
|
raise Fatal('expected ROUTES but got %r' % line)
|
||||||
while 1:
|
while 1:
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('expected route but got %r' % line)
|
raise Fatal('expected route but got %r' % line)
|
||||||
elif line.startswith("NSLIST\n"):
|
elif line.startswith("NSLIST"):
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
(family, width, exclude, ip, fport, lport) = \
|
(family, width, exclude, ip, fport, lport) = line.split(',', 5)
|
||||||
line.strip().split(',', 5)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Fatal('expected route or NSLIST but got %r' % line)
|
raise Fatal('expected route or NSLIST but got %r' % line)
|
||||||
subnets.append((
|
subnets.append((
|
||||||
@ -260,16 +270,16 @@ def main(method_name, syslog):
|
|||||||
debug2('Got subnets: %r' % subnets)
|
debug2('Got subnets: %r' % subnets)
|
||||||
|
|
||||||
nslist = []
|
nslist = []
|
||||||
if line != 'NSLIST\n':
|
if line != 'NSLIST':
|
||||||
raise Fatal('expected NSLIST but got %r' % line)
|
raise Fatal('expected NSLIST but got %r' % line)
|
||||||
while 1:
|
while 1:
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line:
|
||||||
raise Fatal('expected nslist but got %r' % line)
|
raise Fatal('expected nslist but got %r' % line)
|
||||||
elif line.startswith("PORTS "):
|
elif line.startswith("PORTS "):
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
(family, ip) = line.strip().split(',', 1)
|
(family, ip) = line.split(',', 1)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Fatal('expected nslist or PORTS but got %r' % line)
|
raise Fatal('expected nslist or PORTS but got %r' % line)
|
||||||
nslist.append((int(family), ip))
|
nslist.append((int(family), ip))
|
||||||
@ -299,15 +309,13 @@ def main(method_name, syslog):
|
|||||||
debug2('Got ports: %d,%d,%d,%d'
|
debug2('Got ports: %d,%d,%d,%d'
|
||||||
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
% (port_v6, port_v4, dnsport_v6, dnsport_v4))
|
||||||
|
|
||||||
line = stdin.readline(128)
|
line = _read_next_string_line()
|
||||||
if not line:
|
if not line or not line.startswith("GO "):
|
||||||
raise Fatal('expected GO but got %r' % line)
|
|
||||||
elif not line.startswith("GO "):
|
|
||||||
raise Fatal('expected GO but got %r' % line)
|
raise Fatal('expected GO but got %r' % line)
|
||||||
|
|
||||||
_, _, args = line.partition(" ")
|
_, _, args = line.partition(" ")
|
||||||
global sshuttle_pid
|
global sshuttle_pid
|
||||||
udp, user, group, tmark, sshuttle_pid = args.strip().split(" ", 4)
|
udp, user, group, tmark, sshuttle_pid = args.split(" ", 4)
|
||||||
udp = bool(int(udp))
|
udp = bool(int(udp))
|
||||||
sshuttle_pid = int(sshuttle_pid)
|
sshuttle_pid = int(sshuttle_pid)
|
||||||
if user == '-':
|
if user == '-':
|
||||||
@ -350,7 +358,7 @@ def main(method_name, syslog):
|
|||||||
flush_systemd_dns_cache()
|
flush_systemd_dns_cache()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stdout.write('STARTED\n')
|
stdout.write(b'STARTED\n')
|
||||||
stdout.flush()
|
stdout.flush()
|
||||||
except IOError as e: # the parent process probably died
|
except IOError as e: # the parent process probably died
|
||||||
debug3('write to stdout failed: %s' % (e,))
|
debug3('write to stdout failed: %s' % (e,))
|
||||||
@ -360,13 +368,11 @@ def main(method_name, syslog):
|
|||||||
# to stay running so that we don't need a *second* password
|
# to stay running so that we don't need a *second* password
|
||||||
# authentication at shutdown time - that cleanup is important!
|
# authentication at shutdown time - that cleanup is important!
|
||||||
while 1:
|
while 1:
|
||||||
try:
|
line = _read_next_string_line()
|
||||||
line = stdin.readline(128)
|
if not line:
|
||||||
except IOError as e:
|
|
||||||
debug3('read from stdin failed: %s' % (e,))
|
|
||||||
return
|
return
|
||||||
if line.startswith('HOST '):
|
if line.startswith('HOST '):
|
||||||
(name, ip) = line[5:].strip().split(',', 1)
|
(name, ip) = line[5:].split(',', 1)
|
||||||
hostmap[name] = ip
|
hostmap[name] = ip
|
||||||
debug2('setting up /etc/hosts.')
|
debug2('setting up /etc/hosts.')
|
||||||
rewrite_etc_hosts(hostmap, port_v6 or port_v4)
|
rewrite_etc_hosts(hostmap, port_v6 or port_v4)
|
||||||
|
@ -284,7 +284,7 @@ class Method(BaseMethod):
|
|||||||
return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]"))
|
return ipaddress.ip_address(local_addr[:-len(port_suffix)].strip("[]"))
|
||||||
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
|
raise Fatal("Could not find listening address for {}/{}".format(port, proto))
|
||||||
|
|
||||||
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, tmark):
|
def setup_firewall(self, port, dnsport, nslist, family, subnets, udp, user, group, tmark):
|
||||||
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
log(f"{port=}, {dnsport=}, {nslist=}, {family=}, {subnets=}, {udp=}, {user=}, {tmark=}")
|
||||||
|
|
||||||
if nslist or user or udp:
|
if nslist or user or udp:
|
||||||
@ -345,7 +345,7 @@ class Method(BaseMethod):
|
|||||||
if not ev.wait(5): # at most 5 sec
|
if not ev.wait(5): # at most 5 sec
|
||||||
raise Fatal("timeout in wait_for_firewall_ready()")
|
raise Fatal("timeout in wait_for_firewall_ready()")
|
||||||
|
|
||||||
def restore_firewall(self, port, family, udp, user):
|
def restore_firewall(self, port, family, udp, user, group):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_supported_features(self):
|
def get_supported_features(self):
|
||||||
|
@ -260,7 +260,7 @@ def connect(ssh_cmd, rhostport, python, stderr, add_cmd_delimiter, options):
|
|||||||
|
|
||||||
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
|
threading.Thread(target=stream_stdout_to_sock, name='stream_stdout_to_sock', daemon=True).start()
|
||||||
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
|
threading.Thread(target=stream_sock_to_stdin, name='stream_sock_to_stdin', daemon=True).start()
|
||||||
return s2
|
return s2.makefile("rb", buffering=0), s2.makefile("wb", buffering=0)
|
||||||
|
|
||||||
# https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows
|
# https://stackoverflow.com/questions/48671215/howto-workaround-of-close-fds-true-and-redirect-stdout-stderr-on-windows
|
||||||
close_fds = False if sys.platform == 'win32' else True
|
close_fds = False if sys.platform == 'win32' else True
|
||||||
|
@ -10,7 +10,7 @@ import sshuttle.firewall
|
|||||||
|
|
||||||
|
|
||||||
def setup_daemon():
|
def setup_daemon():
|
||||||
stdin = io.StringIO(u"""ROUTES
|
stdin = io.BytesIO(u"""ROUTES
|
||||||
{inet},24,0,1.2.3.0,8000,9000
|
{inet},24,0,1.2.3.0,8000,9000
|
||||||
{inet},32,1,1.2.3.66,8080,8080
|
{inet},32,1,1.2.3.66,8080,8080
|
||||||
{inet6},64,0,2404:6800:4004:80c::,0,0
|
{inet6},64,0,2404:6800:4004:80c::,0,0
|
||||||
@ -127,9 +127,9 @@ def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert stdout.mock_calls == [
|
assert stdout.mock_calls == [
|
||||||
call.write('READY test\n'),
|
call.write(b'READY test\n'),
|
||||||
call.flush(),
|
call.flush(),
|
||||||
call.write('STARTED\n'),
|
call.write(b'STARTED\n'),
|
||||||
call.flush()
|
call.flush()
|
||||||
]
|
]
|
||||||
assert mock_setup_daemon.mock_calls == [call()]
|
assert mock_setup_daemon.mock_calls == [call()]
|
||||||
|
Loading…
Reference in New Issue
Block a user