mirror of
https://github.com/sshuttle/sshuttle.git
synced 2024-11-08 00:54:06 +01:00
Gracefully exit if firewall process receives Ctrl+C/SIGINT.
Typically sshuttle exits by having the main sshuttle client process terminated. This closes file descriptors which the firewall process then sees and uses as a cue to cleanup the firewall rules. The firewall process ignored SIGINT/SIGTERM signals and used setsid() to prevent Ctrl+C from sending signals to the firewall process. This patch makes the firewall process accept SIGINT/SIGTERM signals and then in turn sends a SIGINT signal to the main sshuttle client process which then triggers a regular shutdown as described above. This allows a user to manually send a SIGINT/SIGTERM to either sshuttle process and have it exit gracefully. It also is needed if setsid() fails (known to occur if sudo's use_pty option is used) and then the Ctrl+C SIGINT signal goes to the firewall process. The PID of the sshuttle client process is sent to the firewall process. Using os.getppid() in the firewall process doesn't correctly return the sshuttle client PID.
This commit is contained in:
parent
ae1faa7fa1
commit
ae8af71886
@ -298,8 +298,8 @@ class FirewallClient:
|
||||
else:
|
||||
user = b'%d' % self.user
|
||||
|
||||
self.pfile.write(b'GO %d %s %s\n' %
|
||||
(udp, user, bytes(self.tmark, 'ascii')))
|
||||
self.pfile.write(b'GO %d %s %s %d\n' %
|
||||
(udp, user, bytes(self.tmark, 'ascii'), os.getpid()))
|
||||
self.pfile.flush()
|
||||
|
||||
line = self.pfile.readline()
|
||||
|
@ -13,7 +13,7 @@ from sshuttle.helpers import debug1, debug2, Fatal
|
||||
from sshuttle.methods import get_auto_method, get_method
|
||||
|
||||
HOSTSFILE = '/etc/hosts'
|
||||
|
||||
sshuttle_pid = None
|
||||
|
||||
def rewrite_etc_hosts(hostmap, port):
|
||||
BAKFILE = '%s.sbak' % HOSTSFILE
|
||||
@ -55,6 +55,23 @@ def restore_etc_hosts(hostmap, port):
|
||||
debug2('undoing /etc/hosts changes.')
|
||||
rewrite_etc_hosts({}, port)
|
||||
|
||||
def firewall_exit(signum, frame):
|
||||
# The typical sshuttle exit is that the main sshuttle process
|
||||
# exits, closes file descriptors it uses, and the firewall process
|
||||
# notices that it can't read from stdin anymore and exits
|
||||
# (cleaning up firewall rules).
|
||||
#
|
||||
# However, in some cases, Ctrl+C might get sent to the firewall
|
||||
# process. This might caused if someone manually tries to kill the
|
||||
# firewall process, or if sshuttle was started using sudo's use_pty option
|
||||
# and they try to exit by pressing Ctrl+C. Here, we forward the
|
||||
# Ctrl+C/SIGINT to the main sshuttle process which should trigger
|
||||
# the typical exit process as described above.
|
||||
global sshuttle_pid
|
||||
if sshuttle_pid:
|
||||
debug1("Relaying SIGINT to sshuttle process %d\n" % sshuttle_pid)
|
||||
os.kill(sshuttle_pid, signal.SIGINT)
|
||||
|
||||
|
||||
# Isolate function that needs to be replaced for tests
|
||||
def setup_daemon():
|
||||
@ -65,8 +82,8 @@ def setup_daemon():
|
||||
# disappears; we still have to clean up.
|
||||
signal.signal(signal.SIGHUP, signal.SIG_IGN)
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||
signal.signal(signal.SIGTERM, signal.SIG_IGN)
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
signal.signal(signal.SIGTERM, firewall_exit)
|
||||
signal.signal(signal.SIGINT, firewall_exit)
|
||||
|
||||
# ctrl-c shouldn't be passed along to me. When the main sshuttle dies,
|
||||
# I'll die automatically.
|
||||
@ -230,12 +247,14 @@ def main(method_name, syslog):
|
||||
raise Fatal('expected GO but got %r' % line)
|
||||
|
||||
_, _, args = line.partition(" ")
|
||||
udp, user, tmark = args.strip().split(" ", 2)
|
||||
global sshuttle_pid
|
||||
udp, user, tmark, sshuttle_pid = args.strip().split(" ", 3)
|
||||
udp = bool(int(udp))
|
||||
sshuttle_pid = int(sshuttle_pid)
|
||||
if user == '-':
|
||||
user = None
|
||||
debug2('Got udp: %r, user: %r, tmark: %s' %
|
||||
(udp, user, tmark))
|
||||
debug2('Got udp: %r, user: %r, tmark: %s, sshuttle_pid: %d' %
|
||||
(udp, user, tmark, sshuttle_pid))
|
||||
|
||||
subnets_v6 = [i for i in subnets if i[0] == socket.AF_INET6]
|
||||
nslist_v6 = [i for i in nslist if i[0] == socket.AF_INET6]
|
||||
|
Loading…
Reference in New Issue
Block a user