Merge ab3e20fbd67af8cb752ebd83d1fd96e9f0df8082 into 5c7625c425f12330eec34870511c31667d9fe349

This commit is contained in:
patriceac 2025-04-04 21:06:16 +00:00 committed by GitHub
commit a53b84926e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,6 +8,9 @@ import mimetypes
import os
import traceback
from typing import List, Union
import platform
import subprocess
import ctypes
from easydiffusion import app, model_manager, task_manager, package_manager
from easydiffusion.tasks import RenderTask, FilterTask
@ -30,6 +33,11 @@ from pydantic import BaseModel, Extra
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pycloudflared import try_cloudflare
# Constants for preventing sleep
ES_CONTINUOUS = 0x80000000
ES_SYSTEM_REQUIRED = 0x00000001
ES_DISPLAY_REQUIRED = 0x00000002
log.info(f"started in {app.SD_DIR}")
log.info(f"started at {datetime.datetime.now():%x %X}")
@ -42,6 +50,34 @@ NOCACHE_HEADERS = {
}
PROTECTED_CONFIG_KEYS = ("block_nsfw",) # can't change these via the HTTP API
# Constants for preventing sleep on Windows
ES_CONTINUOUS = 0x80000000
ES_SYSTEM_REQUIRED = 0x00000001
ES_DISPLAY_REQUIRED = 0x00000002
class PreventSleep:
def __init__(self):
self.os_name = platform.system()
self.inhibit_process = None
def prevent_sleep(self):
if self.os_name == "Windows":
ctypes.windll.kernel32.SetThreadExecutionState(
ES_CONTINUOUS | ES_SYSTEM_REQUIRED | ES_DISPLAY_REQUIRED)
elif self.os_name == "Darwin":
self.inhibit_process = subprocess.Popen(['caffeinate'])
elif self.os_name == "Linux":
self.inhibit_process = subprocess.Popen(['systemd-inhibit', '--what=handle-lid-switch', '--why="Prevent sleep during image generation"', 'bash', '-c', 'sleep infinity'])
else:
raise NotImplementedError("Unsupported OS")
def allow_sleep(self):
if self.os_name == "Windows":
ctypes.windll.kernel32.SetThreadExecutionState(ES_CONTINUOUS)
elif self.inhibit_process:
self.inhibit_process.terminate()
self.inhibit_process = None
class NoCacheStaticFiles(StaticFiles):
def __init__(self, directory: str):
@ -284,7 +320,9 @@ def ping_internal(session_id: str = None):
def render_internal(req: dict):
prevent_sleep_obj = PreventSleep()
try:
prevent_sleep_obj.prevent_sleep()
req = convert_legacy_render_req_to_new(req)
# separate out the request data into rendering and task-specific data
@ -316,6 +354,8 @@ def render_internal(req: dict):
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
finally:
prevent_sleep_obj.allow_sleep()
def filter_internal(req: dict):