diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index dac65852..3cebab24 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -6,6 +6,7 @@ from threading import local import psutil import time import shutil +import atexit from easydiffusion.app import ROOT_DIR, getConfig from easydiffusion.model_manager import get_model_dirs @@ -62,6 +63,7 @@ WEBUI_PATCHES = [ "forge_model_crash_recovery.patch", "forge_api_refresh_text_encoders.patch", "forge_loader_force_gc.patch", + "forge_monitor_parent_process.patch", ] backend_process = None @@ -188,6 +190,10 @@ def start_backend(): print("starting", cmd, WEBUI_DIR) backend_process = run_in_conda([cmd], cwd=WEBUI_DIR, env=env, wait=False, output_prefix="[WebUI] ") + # atexit.register isn't 100% reliable, that's why we also use `forge_monitor_parent_process.patch` + # which causes Forge to kill itself if the parent pid passed to it is no longer valid. + atexit.register(backend_process.terminate) + restart_if_dead_thread = threading.Thread(target=restart_if_webui_dies_after_starting) restart_if_dead_thread.start() @@ -366,7 +372,7 @@ def get_env(): "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], "COMMANDLINE_ARGS": [ - f'--api --models-dir "{models_dir}" {model_path_args} --skip-torch-cuda-test --disable-gpu-warning --port {impl.WEBUI_PORT}' + f'--parent-pid {os.getpid()} --api --models-dir "{models_dir}" {model_path_args} --skip-torch-cuda-test --disable-gpu-warning --port {impl.WEBUI_PORT}' ], "SKIP_VENV": ["1"], "SD_WEBUI_RESTARTING": ["1"], diff --git a/ui/easydiffusion/backends/webui/forge_monitor_parent_process.patch b/ui/easydiffusion/backends/webui/forge_monitor_parent_process.patch new file mode 100644 index 00000000..3578768f --- /dev/null +++ b/ui/easydiffusion/backends/webui/forge_monitor_parent_process.patch @@ -0,0 +1,85 @@ +diff --git a/launch.py b/launch.py +index c0568c7b..3919f7dd 100644 +--- a/launch.py ++++ b/launch.py +@@ -2,6 +2,7 @@ + # faulthandler.enable() + + from modules import launch_utils ++from modules import parent_process_monitor + + args = launch_utils.args + python = launch_utils.python +@@ -28,6 +29,10 @@ start = launch_utils.start + + + def main(): ++ if args.parent_pid != -1: ++ print(f"Monitoring parent process for termination. Parent PID: {args.parent_pid}") ++ parent_process_monitor.start_monitor_thread(args.parent_pid) ++ + if args.dump_sysinfo: + filename = launch_utils.dump_sysinfo() + +diff --git a/modules/cmd_args.py b/modules/cmd_args.py +index fcd8a50f..7f684bec 100644 +--- a/modules/cmd_args.py ++++ b/modules/cmd_args.py +@@ -148,3 +148,6 @@ parser.add_argument( + help="Path to directory with annotator model directories", + default=None, + ) ++ ++# Easy Diffusion arguments ++parser.add_argument("--parent-pid", type=int, default=-1, help='parent process id, if running webui as a sub-process') +diff --git a/modules/parent_process_monitor.py b/modules/parent_process_monitor.py +new file mode 100644 +index 00000000..cc3e2049 +--- /dev/null ++++ b/modules/parent_process_monitor.py +@@ -0,0 +1,45 @@ ++# monitors and kills itself when the parent process dies. required when running Forge as a sub-process. ++# modified version of https://stackoverflow.com/a/23587108 ++ ++import sys ++import os ++import threading ++import platform ++import time ++ ++ ++def _monitor_parent_posix(parent_pid): ++ print(f"Monitoring parent pid: {parent_pid}") ++ while True: ++ if os.getppid() != parent_pid: ++ os._exit(0) ++ time.sleep(1) ++ ++ ++def _monitor_parent_windows(parent_pid): ++ from ctypes import WinDLL, WinError ++ from ctypes.wintypes import DWORD, BOOL, HANDLE ++ ++ SYNCHRONIZE = 0x00100000 # Magic value from http://msdn.microsoft.com/en-us/library/ms684880.aspx ++ kernel32 = WinDLL("kernel32.dll") ++ kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) ++ kernel32.OpenProcess.restype = HANDLE ++ ++ handle = kernel32.OpenProcess(SYNCHRONIZE, False, parent_pid) ++ if not handle: ++ raise WinError() ++ ++ # Wait until parent exits ++ from ctypes import windll ++ ++ print(f"Monitoring parent pid: {parent_pid}") ++ windll.kernel32.WaitForSingleObject(handle, -1) ++ os._exit(0) ++ ++ ++def start_monitor_thread(parent_pid): ++ if platform.system() == "Windows": ++ t = threading.Thread(target=_monitor_parent_windows, args=(parent_pid,), daemon=True) ++ else: ++ t = threading.Thread(target=_monitor_parent_posix, args=(parent_pid,), daemon=True) ++ t.start()