diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index a251ede6..ef74be67 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -7,6 +7,9 @@ import mimetypes import os import traceback from typing import List, Union +import platform +import subprocess +import ctypes from easydiffusion import app, model_manager, task_manager, package_manager from easydiffusion.tasks import RenderTask, FilterTask @@ -28,6 +31,13 @@ from pydantic import BaseModel, Extra from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pycloudflared import try_cloudflare +import ctypes + +# Constants for preventing sleep +ES_CONTINUOUS = 0x80000000 +ES_SYSTEM_REQUIRED = 0x00000001 +ES_DISPLAY_REQUIRED = 0x00000002 + log.info(f"started in {app.SD_DIR}") log.info(f"started at {datetime.datetime.now():%x %X}") @@ -40,6 +50,34 @@ NOCACHE_HEADERS = { } PROTECTED_CONFIG_KEYS = ("block_nsfw",) # can't change these via the HTTP API +# Constants for preventing sleep on Windows +ES_CONTINUOUS = 0x80000000 +ES_SYSTEM_REQUIRED = 0x00000001 +ES_DISPLAY_REQUIRED = 0x00000002 + +class PreventSleep: + def __init__(self): + self.os_name = platform.system() + self.inhibit_process = None + + def prevent_sleep(self): + if self.os_name == "Windows": + ctypes.windll.kernel32.SetThreadExecutionState( + ES_CONTINUOUS | ES_SYSTEM_REQUIRED | ES_DISPLAY_REQUIRED) + elif self.os_name == "Darwin": + self.inhibit_process = subprocess.Popen(['caffeinate']) + elif self.os_name == "Linux": + self.inhibit_process = subprocess.Popen(['systemd-inhibit', '--what=handle-lid-switch', '--why="Prevent sleep during image generation"', 'bash', '-c', 'sleep infinity']) + else: + raise NotImplementedError("Unsupported OS") + + def allow_sleep(self): + if self.os_name == "Windows": + ctypes.windll.kernel32.SetThreadExecutionState(ES_CONTINUOUS) + elif self.inhibit_process: + self.inhibit_process.terminate() + self.inhibit_process = None + class NoCacheStaticFiles(StaticFiles): def __init__(self, directory: str): @@ -267,7 +305,9 @@ def ping_internal(session_id: str = None): def render_internal(req: dict): + prevent_sleep_obj = PreventSleep() try: + prevent_sleep_obj.prevent_sleep() req = convert_legacy_render_req_to_new(req) # separate out the request data into rendering and task-specific data @@ -299,6 +339,8 @@ def render_internal(req: dict): except Exception as e: log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + finally: + prevent_sleep_obj.allow_sleep() def filter_internal(req: dict):