diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 5d6d315a..4ee0d830 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -18,7 +18,7 @@ os_name = platform.system() modules_to_check = { "torch": ("1.11.0", "1.13.1", "2.0.0"), "torchvision": ("0.12.0", "0.14.1", "0.15.1"), - "sdkit": "1.0.144", + "sdkit": "1.0.146", "stable-diffusion-sdkit": "2.1.4", "rich": "12.6.0", "uvicorn": "0.19.0", diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index 921886f1..e181f9b8 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -32,6 +32,8 @@ logging.basicConfig( SD_DIR = os.getcwd() +ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, "..")) + SD_UI_DIR = os.getenv("SD_UI_PATH", None) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) diff --git a/ui/easydiffusion/package_manager.py b/ui/easydiffusion/package_manager.py new file mode 100644 index 00000000..2ec83893 --- /dev/null +++ b/ui/easydiffusion/package_manager.py @@ -0,0 +1,93 @@ +import sys +import os +import platform +from importlib.metadata import version as pkg_version + +from sdkit.utils import log + +from easydiffusion import app + +# future home of scripts/check_modules.py + +manifest = { + "tensorrt": { + "install": ["nvidia-cudnn", "tensorrt-libs", "tensorrt"], + "uninstall": ["tensorrt"], + # TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error) + } +} +installing = [] + +# remove this once TRT releases on pypi +if platform.system() == "Windows": + trt_dir = os.path.join(app.ROOT_DIR, "tensorrt") + if os.path.exists(trt_dir): + files = os.listdir(trt_dir) + + packages = manifest["tensorrt"]["install"] + packages = tuple(p.replace("-", "_") for p in packages) + + wheels = [] + for p in packages: + f = next((f for f in files if f.startswith(p) and f.endswith((".whl", ".tar.gz"))), None) + if f: + wheels.append(os.path.join(trt_dir, f)) + + manifest["tensorrt"]["install"] = wheels + + +def get_installed_packages() -> list: + return {module_name: version(module_name) for module_name in manifest if is_installed(module_name)} + + +def is_installed(module_name) -> bool: + return version(module_name) is not None + + +def install(module_name): + if is_installed(module_name): + log.info(f"{module_name} has already been installed!") + return + if module_name in installing: + log.info(f"{module_name} is already installing!") + return + + if module_name not in manifest: + raise RuntimeError(f"Can't install unknown package: {module_name}!") + + commands = manifest[module_name]["install"] + commands = [f"python -m pip install --upgrade {cmd}" for cmd in commands] + + installing.append(module_name) + + try: + for cmd in commands: + print(">", cmd) + if os.system(cmd) != 0: + raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.") + finally: + installing.remove(module_name) + + +def uninstall(module_name): + if not is_installed(module_name): + log.info(f"{module_name} hasn't been installed!") + return + + if module_name not in manifest: + raise RuntimeError(f"Can't uninstall unknown package: {module_name}!") + + commands = manifest[module_name]["uninstall"] + commands = [f"python -m pip uninstall -y {cmd}" for cmd in commands] + + for cmd in commands: + print(">", cmd) + if os.system(cmd) != 0: + raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.") + + +def version(module_name: str) -> str: + try: + return pkg_version(module_name) + except: + return None diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 0f1890c3..a8f848fd 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -8,7 +8,7 @@ import os import traceback from typing import List, Union -from easydiffusion import app, model_manager, task_manager +from easydiffusion import app, model_manager, task_manager, package_manager from easydiffusion.tasks import RenderTask, FilterTask from easydiffusion.types import ( GenerateImageRequest, @@ -135,6 +135,10 @@ def init(): def stop_cloudflare_tunnel(req: dict): return stop_cloudflare_tunnel_internal(req) + @server_api.post("/package/{package_name:str}") + def modify_package(package_name: str, req: dict): + return modify_package_internal(package_name, req) + @server_api.get("/") def read_root(): return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS) @@ -226,16 +230,24 @@ def ping_internal(session_id: str = None): if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) raise HTTPException(status_code=500, detail="Render thread is dead.") + if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) + # Alive response = {"status": str(task_manager.current_state)} + if session_id: session = task_manager.get_cached_session(session_id, update_ttl=True) response["tasks"] = {id(t): t.status for t in session.tasks} + response["devices"] = task_manager.get_devices() + response["packages_installed"] = package_manager.get_installed_packages() + response["packages_installing"] = package_manager.installing + if cloudflare.address != None: response["cloudflare"] = cloudflare.address + return JSONResponse(response, headers=NOCACHE_HEADERS) @@ -423,3 +435,19 @@ def stop_cloudflare_tunnel_internal(req: dict): log.error(str(e)) log.error(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e)) + + +def modify_package_internal(package_name: str, req: dict): + try: + cmd = req["command"] + if cmd not in ("install", "uninstall"): + raise RuntimeError(f"Unknown command: {cmd}") + + cmd = getattr(package_manager, cmd) + cmd(package_name) + + return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) + except Exception as e: + log.error(str(e)) + log.error(traceback.format_exc()) + return HTTPException(status_code=500, detail=str(e)) diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 27b53b6f..699b4494 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -373,6 +373,12 @@ def get_devices(): finally: manager_lock.release() + # temp until TRT releases + import os + from easydiffusion import app + + devices["enable_trt"] = os.path.exists(os.path.join(app.ROOT_DIR, "tensorrt")) + return devices diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index 42bcc76a..bbc36aa5 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -60,7 +60,11 @@ class RenderTask(Task): model_manager.resolve_model_paths(self.models_data) models_to_force_reload = [] - if runtime.set_vram_optimizations(context) or self.has_clip_skip_changed(context): + if ( + runtime.set_vram_optimizations(context) + or self.has_param_changed(context, "clip_skip") + or self.has_param_changed(context, "convert_to_tensorrt") + ): models_to_force_reload.append("stable-diffusion") model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload) @@ -78,13 +82,15 @@ class RenderTask(Task): step_callback, ) - def has_clip_skip_changed(self, context): + def has_param_changed(self, context, param_name): if not context.test_diffusers: return False + if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: + return True model = context.models["stable-diffusion"] - new_clip_skip = self.models_data.model_params.get("stable-diffusion", {}).get("clip_skip", False) - return model["clip_skip"] != new_clip_skip + new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) + return model["params"].get(param_name) != new_val def make_images( diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index a37da9ef..b1d55f5a 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -217,7 +217,10 @@ def convert_legacy_render_req_to_new(old_req: dict): # move the model params if model_paths["stable-diffusion"]: - model_params["stable-diffusion"] = {"clip_skip": bool(old_req.get("clip_skip", False))} + model_params["stable-diffusion"] = { + "clip_skip": bool(old_req.get("clip_skip", False)), + "convert_to_tensorrt": bool(old_req.get("convert_to_tensorrt", False)), + } # move the filter params if model_paths["realesrgan"]: diff --git a/ui/index.html b/ui/index.html index 776eb627..2d47433e 100644 --- a/ui/index.html +++ b/ui/index.html @@ -146,6 +146,14 @@ Click to learn more about custom models +