Merge pull request #712 from skuhl/sudo-use-pty-fix

Fix sshuttle when using sudo's use_pty option.
This commit is contained in:
Brian May 2022-01-10 10:03:55 +11:00 committed by GitHub
commit 0ccd243a65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 37 deletions

View File

@ -18,7 +18,7 @@ while 1:
name = name.decode("ASCII") name = name.decode("ASCII")
nbytes = int(sys.stdin.readline()) nbytes = int(sys.stdin.readline())
if verbosity >= 2: if verbosity >= 2:
sys.stderr.write(' s: assembling %r (%d bytes)\n' sys.stderr.write(' s: assembling %r (%d bytes)\r\n'
% (name, nbytes)) % (name, nbytes))
content = z.decompress(sys.stdin.read(nbytes)) content = z.decompress(sys.stdin.read(nbytes))

View File

@ -298,8 +298,8 @@ class FirewallClient:
else: else:
user = b'%d' % self.user user = b'%d' % self.user
self.pfile.write(b'GO %d %s %s\n' % self.pfile.write(b'GO %d %s %s %d\n' %
(udp, user, bytes(self.tmark, 'ascii'))) (udp, user, bytes(self.tmark, 'ascii'), os.getpid()))
self.pfile.flush() self.pfile.flush()
line = self.pfile.readline() line = self.pfile.readline()

View File

@ -13,6 +13,7 @@ from sshuttle.helpers import log, debug1, debug2, Fatal
from sshuttle.methods import get_auto_method, get_method from sshuttle.methods import get_auto_method, get_method
HOSTSFILE = '/etc/hosts' HOSTSFILE = '/etc/hosts'
sshuttle_pid = None
def rewrite_etc_hosts(hostmap, port): def rewrite_etc_hosts(hostmap, port):
@ -56,6 +57,24 @@ def restore_etc_hosts(hostmap, port):
rewrite_etc_hosts({}, port) rewrite_etc_hosts({}, port)
def firewall_exit(signum, frame):
# The typical sshuttle exit is that the main sshuttle process
# exits, closes file descriptors it uses, and the firewall process
# notices that it can't read from stdin anymore and exits
# (cleaning up firewall rules).
#
# However, in some cases, Ctrl+C might get sent to the firewall
# process. This might caused if someone manually tries to kill the
# firewall process, or if sshuttle was started using sudo's use_pty option
# and they try to exit by pressing Ctrl+C. Here, we forward the
# Ctrl+C/SIGINT to the main sshuttle process which should trigger
# the typical exit process as described above.
global sshuttle_pid
if sshuttle_pid:
debug1("Relaying SIGINT to sshuttle process %d\n" % sshuttle_pid)
os.kill(sshuttle_pid, signal.SIGINT)
# Isolate function that needs to be replaced for tests # Isolate function that needs to be replaced for tests
def setup_daemon(): def setup_daemon():
if os.getuid() != 0: if os.getuid() != 0:
@ -65,19 +84,20 @@ def setup_daemon():
# disappears; we still have to clean up. # disappears; we still have to clean up.
signal.signal(signal.SIGHUP, signal.SIG_IGN) signal.signal(signal.SIGHUP, signal.SIG_IGN)
signal.signal(signal.SIGPIPE, signal.SIG_IGN) signal.signal(signal.SIGPIPE, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_IGN) signal.signal(signal.SIGTERM, firewall_exit)
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, firewall_exit)
# ctrl-c shouldn't be passed along to me. When the main sshuttle dies, # Calling setsid() here isn't strictly necessary. However, it forces
# I'll die automatically. # Ctrl+C to get sent to the main sshuttle process instead of to
# the firewall process---which is our preferred way to shutdown.
# Nonetheless, if the firewall process receives a SIGTERM/SIGINT
# signal, it will relay a SIGINT to the main sshuttle process
# automatically.
try: try:
os.setsid() os.setsid()
except OSError: except OSError:
raise Fatal("setsid() failed. This may occur if you are using sudo's " # setsid() fails if sudo is configured with the use_pty option.
"use_pty option. sshuttle does not currently work with " pass
"this option. An imperfect workaround: Run the sshuttle "
"command with sudo instead of running it as a regular "
"user and entering the sudo password when prompted.")
# because of limitations of the 'su' command, the *real* stdin/stdout # because of limitations of the 'su' command, the *real* stdin/stdout
# are both attached to stdout initially. Clone stdout into stdin so we # are both attached to stdout initially. Clone stdout into stdin so we
@ -238,12 +258,14 @@ def main(method_name, syslog):
raise Fatal('expected GO but got %r' % line) raise Fatal('expected GO but got %r' % line)
_, _, args = line.partition(" ") _, _, args = line.partition(" ")
udp, user, tmark = args.strip().split(" ", 2) global sshuttle_pid
udp, user, tmark, sshuttle_pid = args.strip().split(" ", 3)
udp = bool(int(udp)) udp = bool(int(udp))
sshuttle_pid = int(sshuttle_pid)
if user == '-': if user == '-':
user = None user = None
debug2('Got udp: %r, user: %r, tmark: %s' % debug2('Got udp: %r, user: %r, tmark: %s, sshuttle_pid: %d' %
(udp, user, tmark)) (udp, user, tmark, sshuttle_pid))
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6] subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6] nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]

View File

@ -18,15 +18,19 @@ def log(s):
# Put newline at end of string if line doesn't have one. # Put newline at end of string if line doesn't have one.
if not s.endswith("\n"): if not s.endswith("\n"):
s = s+"\n" s = s+"\n"
# Allow multi-line messages
if s.find("\n") != -1: prefix = logprefix
prefix = logprefix s = s.rstrip("\n")
s = s.rstrip("\n") for line in s.split("\n"):
for line in s.split("\n"): # We output with \r\n instead of \n because when we use
sys.stderr.write(prefix + line + "\n") # sudo with the use_pty option, the firewall process, the
prefix = " " # other processes printing to the terminal will have the
else: # \n move to the next line, but they will fail to reset
sys.stderr.write(logprefix + s) # cursor to the beginning of the line. Printing output
# with \r\n endings fixes that problem and does not appear
# to cause problems elsewhere.
sys.stderr.write(prefix + line + "\r\n")
prefix = " "
sys.stderr.flush() sys.stderr.flush()
except IOError: except IOError:
# this could happen if stderr gets forcibly disconnected, eg. because # this could happen if stderr gets forcibly disconnected, eg. because

View File

@ -15,7 +15,7 @@ NSLIST
{inet},1.2.3.33 {inet},1.2.3.33
{inet6},2404:6800:4004:80c::33 {inet6},2404:6800:4004:80c::33
PORTS 1024,1025,1026,1027 PORTS 1024,1025,1026,1027
GO 1 - 0x01 GO 1 - 0x01 12345
HOST 1.2.3.3,existing HOST 1.2.3.3,existing
""".format(inet=AF_INET, inet6=AF_INET6)) """.format(inet=AF_INET, inet6=AF_INET6))
stdout = Mock() stdout = Mock()

View File

@ -24,19 +24,19 @@ def test_log(mock_stderr, mock_stdout):
call.flush(), call.flush(),
] ]
assert mock_stderr.mock_calls == [ assert mock_stderr.mock_calls == [
call.write('prefix: message\n'), call.write('prefix: message\r\n'),
call.flush(), call.flush(),
call.write('prefix: abc\n'), call.write('prefix: abc\r\n'),
call.flush(), call.flush(),
call.write('prefix: message 1\n'), call.write('prefix: message 1\r\n'),
call.flush(), call.flush(),
call.write('prefix: message 2\n'), call.write('prefix: message 2\r\n'),
call.write(' line2\n'), call.write(' line2\r\n'),
call.write(' line3\n'), call.write(' line3\r\n'),
call.flush(), call.flush(),
call.write('prefix: message 3\n'), call.write('prefix: message 3\r\n'),
call.write(' line2\n'), call.write(' line2\r\n'),
call.write(' line3\n'), call.write(' line3\r\n'),
call.flush(), call.flush(),
] ]
@ -51,7 +51,7 @@ def test_debug1(mock_stderr, mock_stdout):
call.flush(), call.flush(),
] ]
assert mock_stderr.mock_calls == [ assert mock_stderr.mock_calls == [
call.write('prefix: message\n'), call.write('prefix: message\r\n'),
call.flush(), call.flush(),
] ]
@ -76,7 +76,7 @@ def test_debug2(mock_stderr, mock_stdout):
call.flush(), call.flush(),
] ]
assert mock_stderr.mock_calls == [ assert mock_stderr.mock_calls == [
call.write('prefix: message\n'), call.write('prefix: message\r\n'),
call.flush(), call.flush(),
] ]
@ -101,7 +101,7 @@ def test_debug3(mock_stderr, mock_stdout):
call.flush(), call.flush(),
] ]
assert mock_stderr.mock_calls == [ assert mock_stderr.mock_calls == [
call.write('prefix: message\n'), call.write('prefix: message\r\n'),
call.flush(), call.flush(),
] ]