From 45f350239e4f5edd72f3b71debb607d9add8900b Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 23 Jul 2024 17:56:35 -0500 Subject: [PATCH 01/64] Don't break if we can't write a file --- ui/easydiffusion/model_manager.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index f6e1e2d0..8b234884 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -282,9 +282,11 @@ def make_model_folders(): help_file_name = f"Place your {model_type} model files here.txt" help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' - - with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: - f.write(help_file_contents) + try: + with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: + f.write(help_file_contents) + except: + pass def is_malicious_model(file_path): From 89f5e07619a7fadfb4f1168e163e96b9ecdf6f82 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 24 Jul 2024 09:47:35 +0530 Subject: [PATCH 02/64] Log the exception while trying to create an extension info file --- ui/easydiffusion/model_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 8b234884..e6ba997a 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -285,8 +285,8 @@ def make_model_folders(): try: with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: f.write(help_file_contents) - except: - pass + except Exception as e: + log.exception(e) def is_malicious_model(file_path): From b7d46be530db3320108408cb4fa3e535947a9d03 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 9 Sep 2024 18:48:38 +0530 Subject: [PATCH 03/64] Use SD 1.4 instead of 1.5 during installation --- NSIS/sdui.nsi | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/NSIS/sdui.nsi b/NSIS/sdui.nsi index 92bec886..2995328c 100644 --- a/NSIS/sdui.nsi +++ b/NSIS/sdui.nsi @@ -235,8 +235,8 @@ Section "MainSection" SEC01 CreateDirectory "$SMPROGRAMS\Easy Diffusion" CreateShortCut "$SMPROGRAMS\Easy Diffusion\Easy Diffusion.lnk" "$INSTDIR\Start Stable Diffusion UI.cmd" "" "$INSTDIR\installer_files\cyborg_flower_girl.ico" - DetailPrint 'Downloading the Stable Diffusion 1.5 model...' - NScurl::http get "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/sd-v1-5.safetensors" "$INSTDIR\models\stable-diffusion\sd-v1-5.safetensors" /CANCEL /INSIST /END + DetailPrint 'Downloading the Stable Diffusion 1.4 model...' + NScurl::http get "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/sd-v1-4.safetensors" "$INSTDIR\models\stable-diffusion\sd-v1-4.safetensors" /CANCEL /INSIST /END DetailPrint 'Downloading the GFPGAN model...' NScurl::http get "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" "$INSTDIR\models\gfpgan\GFPGANv1.4.pth" /CANCEL /INSIST /END From 6559c41b2e05e0ac1b3d9e6f2a8952704156c47e Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 24 Sep 2024 17:27:50 +0530 Subject: [PATCH 04/64] minor log import change --- ui/easydiffusion/package_manager.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ui/easydiffusion/package_manager.py b/ui/easydiffusion/package_manager.py index c28a58a1..9df7f1fd 100644 --- a/ui/easydiffusion/package_manager.py +++ b/ui/easydiffusion/package_manager.py @@ -3,8 +3,6 @@ import os import platform from importlib.metadata import version as pkg_version -from sdkit.utils import log - from easydiffusion import app # future home of scripts/check_modules.py @@ -50,6 +48,8 @@ def is_installed(module_name) -> bool: def install(module_name): + from easydiffusion.utils import log + if is_installed(module_name): log.info(f"{module_name} has already been installed!") return @@ -79,6 +79,8 @@ def 
install(module_name): def uninstall(module_name): + from easydiffusion.utils import log + if not is_installed(module_name): log.info(f"{module_name} hasn't been installed!") return From a0de0b5814cb8cc3f576b980955a5ebc9e34f7e2 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 24 Sep 2024 17:40:00 +0530 Subject: [PATCH 05/64] Minor refactoring of SD dir variables --- ui/easydiffusion/app.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index 43d0e3c4..a1f607ef 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -36,10 +36,10 @@ ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, "..")) SD_UI_DIR = os.getenv("SD_UI_PATH", None) -CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) -BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket")) +CONFIG_DIR = os.path.abspath(os.path.join(ROOT_DIR, "scripts")) +BUCKET_DIR = os.path.abspath(os.path.join(ROOT_DIR, "bucket")) -USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins")) +USER_PLUGINS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "plugins")) CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins")) USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui") @@ -77,7 +77,7 @@ IMAGE_EXTENSIONS = [ ".avif", ".svg", ] -CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers")) +CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "modifiers")) CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [ ".portrait", "_portrait", @@ -91,7 +91,7 @@ CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [ "-landscape", ] -MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) +MODELS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "models")) def init(): @@ -105,7 +105,7 @@ def init(): config = getConfig() config_models_dir = config.get("models_dir", None) - if (config_models_dir is not None and config_models_dir != ""): + if config_models_dir is not None and 
config_models_dir != "": MODELS_DIR = config_models_dir From 2eb0c9106ae47f871524bb73464eedb609ab74cf Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 24 Sep 2024 17:43:34 +0530 Subject: [PATCH 06/64] Temporarily disable the auto-selection of the appropriate controlnet model --- ui/media/js/main.js | 60 ++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/ui/media/js/main.js b/ui/media/js/main.js index bff87b3b..1cbd3288 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -1845,36 +1845,36 @@ controlImagePreview.addEventListener("load", onControlnetModelChange) controlImagePreview.addEventListener("unload", onControlnetModelChange) onControlnetModelChange() -function onControlImageFilterChange() { - let filterId = controlImageFilterField.value - if (filterId.includes("openpose")) { - controlnetModelField.value = "control_v11p_sd15_openpose" - } else if (filterId === "canny") { - controlnetModelField.value = "control_v11p_sd15_canny" - } else if (filterId === "mlsd") { - controlnetModelField.value = "control_v11p_sd15_mlsd" - } else if (filterId === "mlsd") { - controlnetModelField.value = "control_v11p_sd15_mlsd" - } else if (filterId.includes("scribble")) { - controlnetModelField.value = "control_v11p_sd15_scribble" - } else if (filterId.includes("softedge")) { - controlnetModelField.value = "control_v11p_sd15_softedge" - } else if (filterId === "normal_bae") { - controlnetModelField.value = "control_v11p_sd15_normalbae" - } else if (filterId.includes("depth")) { - controlnetModelField.value = "control_v11f1p_sd15_depth" - } else if (filterId === "lineart_anime") { - controlnetModelField.value = "control_v11p_sd15s2_lineart_anime" - } else if (filterId.includes("lineart")) { - controlnetModelField.value = "control_v11p_sd15_lineart" - } else if (filterId === "shuffle") { - controlnetModelField.value = "control_v11e_sd15_shuffle" - } else if (filterId === "segment") { - controlnetModelField.value = 
"control_v11p_sd15_seg" - } -} -controlImageFilterField.addEventListener("change", onControlImageFilterChange) -onControlImageFilterChange() +// function onControlImageFilterChange() { +// let filterId = controlImageFilterField.value +// if (filterId.includes("openpose")) { +// controlnetModelField.value = "control_v11p_sd15_openpose" +// } else if (filterId === "canny") { +// controlnetModelField.value = "control_v11p_sd15_canny" +// } else if (filterId === "mlsd") { +// controlnetModelField.value = "control_v11p_sd15_mlsd" +// } else if (filterId === "mlsd") { +// controlnetModelField.value = "control_v11p_sd15_mlsd" +// } else if (filterId.includes("scribble")) { +// controlnetModelField.value = "control_v11p_sd15_scribble" +// } else if (filterId.includes("softedge")) { +// controlnetModelField.value = "control_v11p_sd15_softedge" +// } else if (filterId === "normal_bae") { +// controlnetModelField.value = "control_v11p_sd15_normalbae" +// } else if (filterId.includes("depth")) { +// controlnetModelField.value = "control_v11f1p_sd15_depth" +// } else if (filterId === "lineart_anime") { +// controlnetModelField.value = "control_v11p_sd15s2_lineart_anime" +// } else if (filterId.includes("lineart")) { +// controlnetModelField.value = "control_v11p_sd15_lineart" +// } else if (filterId === "shuffle") { +// controlnetModelField.value = "control_v11e_sd15_shuffle" +// } else if (filterId === "segment") { +// controlnetModelField.value = "control_v11p_sd15_seg" +// } +// } +// controlImageFilterField.addEventListener("change", onControlImageFilterChange) +// onControlImageFilterChange() upscaleModelField.disabled = !useUpscalingField.checked upscaleAmountField.disabled = !useUpscalingField.checked From 9a12a8618cb426c99e4c6175d8234339fc94a5fe Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Oct 2024 10:54:58 +0530 Subject: [PATCH 07/64] First working version of dynamic backends, with Forge and ed_diffusers (v3) and ed_classic (v2). 
Does not auto-install Forge yet --- ui/easydiffusion/app.py | 23 +- ui/easydiffusion/backend_manager.py | 105 ++++ ui/easydiffusion/backends/ed_classic.py | 27 + ui/easydiffusion/backends/ed_diffusers.py | 27 + ui/easydiffusion/backends/sdkit_common.py | 237 ++++++++ ui/easydiffusion/backends/webui/__init__.py | 155 +++++ ui/easydiffusion/backends/webui/impl.py | 639 ++++++++++++++++++++ ui/easydiffusion/easydb/schemas.py | 1 - ui/easydiffusion/model_manager.py | 72 ++- ui/easydiffusion/runtime.py | 34 +- ui/easydiffusion/server.py | 24 +- ui/easydiffusion/task_manager.py | 15 +- ui/easydiffusion/tasks/filter_images.py | 81 +-- ui/easydiffusion/tasks/render_images.py | 224 +++---- ui/easydiffusion/types.py | 67 +- ui/easydiffusion/utils/__init__.py | 3 +- ui/easydiffusion/utils/nsfw_checker.py | 80 +++ ui/easydiffusion/utils/save_utils.py | 2 +- ui/index.html | 131 +++- ui/main.py | 3 - ui/media/css/auto-save.css | 3 +- ui/media/css/main.css | 15 +- ui/media/js/main.js | 15 +- ui/media/js/parameters.js | 88 +-- 24 files changed, 1715 insertions(+), 356 deletions(-) create mode 100644 ui/easydiffusion/backend_manager.py create mode 100644 ui/easydiffusion/backends/ed_classic.py create mode 100644 ui/easydiffusion/backends/ed_diffusers.py create mode 100644 ui/easydiffusion/backends/sdkit_common.py create mode 100644 ui/easydiffusion/backends/webui/__init__.py create mode 100644 ui/easydiffusion/backends/webui/impl.py create mode 100644 ui/easydiffusion/utils/nsfw_checker.py diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index a1f607ef..ec856991 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -11,7 +11,7 @@ from ruamel.yaml import YAML import urllib import warnings -from easydiffusion import task_manager +from easydiffusion import task_manager, backend_manager from easydiffusion.utils import log from rich.logging import RichHandler from rich.console import Console @@ -60,7 +60,7 @@ APP_CONFIG_DEFAULTS = { "ui": { 
"open_browser_on_start": True, }, - "use_v3_engine": True, + "backend": "ed_diffusers", } IMAGE_EXTENSIONS = [ @@ -108,6 +108,8 @@ def init(): if config_models_dir is not None and config_models_dir != "": MODELS_DIR = config_models_dir + backend_manager.start_backend() + def init_render_threads(): load_server_plugins() @@ -124,9 +126,9 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): shutil.move(config_legacy_yaml, config_yaml_path) def set_config_on_startup(config: dict): - if getConfig.__use_v3_engine_on_startup is None: - getConfig.__use_v3_engine_on_startup = config.get("use_v3_engine", True) - config["config_on_startup"] = {"use_v3_engine": getConfig.__use_v3_engine_on_startup} + if getConfig.__use_backend_on_startup is None: + getConfig.__use_backend_on_startup = config.get("backend", "ed_diffusers") + config["config_on_startup"] = {"backend": getConfig.__use_backend_on_startup} if os.path.isfile(config_yaml_path): try: @@ -144,6 +146,15 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): else: config["net"]["listen_to_network"] = True + if "backend" not in config: + if "use_v3_engine" in config: + config["backend"] = "ed_diffusers" if config["use_v3_engine"] else "ed_classic" + else: + config["backend"] = "ed_diffusers" + # this default will need to be smarter when WebUI becomes the main backend, but needs to maintain backwards + # compatibility with existing ED 3.0 installations that haven't opted into the WebUI backend, and haven't + # set a "use_v3_engine" flag in their config + set_config_on_startup(config) return config @@ -174,7 +185,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): return default_val -getConfig.__use_v3_engine_on_startup = None +getConfig.__use_backend_on_startup = None def setConfig(config): diff --git a/ui/easydiffusion/backend_manager.py b/ui/easydiffusion/backend_manager.py new file mode 100644 index 00000000..1c280db4 --- /dev/null +++ b/ui/easydiffusion/backend_manager.py @@ -0,0 +1,105 @@ +import os +import ast +import sys 
+import importlib.util +import traceback + +from easydiffusion.utils import log + +backend = None +curr_backend_name = None + + +def is_valid_backend(file_path): + with open(file_path, "r", encoding="utf-8") as file: + node = ast.parse(file.read()) + + # Check for presence of a dictionary named 'ed_info' + for item in node.body: + if isinstance(item, ast.Assign): + for target in item.targets: + if isinstance(target, ast.Name) and target.id == "ed_info": + return True + return False + + +def find_valid_backends(root_dir) -> dict: + backends_path = os.path.join(root_dir, "backends") + valid_backends = {} + + if not os.path.exists(backends_path): + return valid_backends + + for item in os.listdir(backends_path): + item_path = os.path.join(backends_path, item) + + if os.path.isdir(item_path): + init_file = os.path.join(item_path, "__init__.py") + if os.path.exists(init_file) and is_valid_backend(init_file): + valid_backends[item] = item_path + elif item.endswith(".py"): + if is_valid_backend(item_path): + backend_name = os.path.splitext(item)[0] # strip the .py extension + valid_backends[backend_name] = item_path + + return valid_backends + + +def load_backend_module(backend_name, backend_dict): + if backend_name not in backend_dict: + raise ValueError(f"Backend '{backend_name}' not found.") + + module_path = backend_dict[backend_name] + + mod_dir = os.path.dirname(module_path) + + sys.path.insert(0, mod_dir) + + # If it's a package (directory), add its parent directory to sys.path + if os.path.isdir(module_path): + module_path = os.path.join(module_path, "__init__.py") + + spec = importlib.util.spec_from_file_location(backend_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if mod_dir in sys.path: + sys.path.remove(mod_dir) + + log.info(f"Loaded backend: {module}") + + return module + + +def start_backend(): + global backend, curr_backend_name + + from easydiffusion.app import getConfig, ROOT_DIR + + curr_dir = 
os.path.dirname(__file__) + + backends = find_valid_backends(curr_dir) + plugin_backends = find_valid_backends(ROOT_DIR) + backends.update(plugin_backends) + + config = getConfig() + backend_name = config["backend"] + + if backend_name not in backends: + raise RuntimeError( + f"Couldn't find the backend configured in config.yaml: {backend_name}. Please check the name!" + ) + + if backend is not None and backend_name != curr_backend_name: + try: + backend.stop_backend() + except: + log.exception(traceback.format_exc()) + + log.info(f"Loading backend: {backend_name}") + backend = load_backend_module(backend_name, backends) + + try: + backend.start_backend() + except: + log.exception(traceback.format_exc()) diff --git a/ui/easydiffusion/backends/ed_classic.py b/ui/easydiffusion/backends/ed_classic.py new file mode 100644 index 00000000..c9cf745e --- /dev/null +++ b/ui/easydiffusion/backends/ed_classic.py @@ -0,0 +1,27 @@ +from sdkit_common import ( + start_backend, + stop_backend, + install_backend, + uninstall_backend, + create_sdkit_context, + ping, + load_model, + unload_model, + set_options, + generate_images, + filter_images, + get_url, + stop_rendering, + refresh_models, + list_controlnet_filters, +) + +ed_info = { + "name": "Classic backend for Easy Diffusion v2", + "version": (1, 0, 0), + "type": "backend", +} + + +def create_context(): + return create_sdkit_context(use_diffusers=False) diff --git a/ui/easydiffusion/backends/ed_diffusers.py b/ui/easydiffusion/backends/ed_diffusers.py new file mode 100644 index 00000000..c905652d --- /dev/null +++ b/ui/easydiffusion/backends/ed_diffusers.py @@ -0,0 +1,27 @@ +from sdkit_common import ( + start_backend, + stop_backend, + install_backend, + uninstall_backend, + create_sdkit_context, + ping, + load_model, + unload_model, + set_options, + generate_images, + filter_images, + get_url, + stop_rendering, + refresh_models, + list_controlnet_filters, +) + +ed_info = { + "name": "Diffusers Backend for Easy Diffusion v3", + 
"version": (1, 0, 0), + "type": "backend", +} + + +def create_context(): + return create_sdkit_context(use_diffusers=True) diff --git a/ui/easydiffusion/backends/sdkit_common.py b/ui/easydiffusion/backends/sdkit_common.py new file mode 100644 index 00000000..d7a49c3e --- /dev/null +++ b/ui/easydiffusion/backends/sdkit_common.py @@ -0,0 +1,237 @@ +from sdkit import Context + +from easydiffusion.types import UserInitiatedStop + +from sdkit.utils import ( + diffusers_latent_samples_to_images, + gc, + img_to_base64_str, + latent_samples_to_images, +) + +opts = {} + + +def install_backend(): + pass + + +def start_backend(): + print("Started sdkit backend") + + +def stop_backend(): + pass + + +def uninstall_backend(): + pass + + +def create_sdkit_context(use_diffusers): + c = Context() + c.test_diffusers = use_diffusers + return c + + +def ping(timeout=1): + return True + + +def load_model(context, model_type, **kwargs): + from sdkit.models import load_model + + load_model(context, model_type, **kwargs) + + +def unload_model(context, model_type, **kwargs): + from sdkit.models import unload_model + + unload_model(context, model_type, **kwargs) + + +def set_options(context, **kwargs): + if "vae_tiling" in kwargs and context.test_diffusers: + pipe = context.models["stable-diffusion"]["default"] + vae_tiling = kwargs["vae_tiling"] + + if vae_tiling: + if hasattr(pipe, "enable_vae_tiling"): + pipe.enable_vae_tiling() + else: + if hasattr(pipe, "disable_vae_tiling"): + pipe.disable_vae_tiling() + + for key in ( + "output_format", + "output_quality", + "output_lossless", + "stream_image_progress", + "stream_image_progress_interval", + ): + if key in kwargs: + opts[key] = kwargs[key] + + +def generate_images( + context: Context, + callback=None, + controlnet_filter=None, + output_type="pil", + **req, +): + from sdkit.generate import generate_images + + if req["init_image"] is not None and not context.test_diffusers: + req["sampler_name"] = "ddim" + + gc(context) + + 
context.stop_processing = False + + if req["control_image"] and controlnet_filter: + controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter) + req["control_image"] = filter_images(context, req["control_image"], controlnet_filter)[0] + + callback = make_step_callback(context, callback) + + try: + images = generate_images(context, callback=callback, **req) + except UserInitiatedStop: + images = [] + if context.partial_x_samples is not None: + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, context.partial_x_samples) + else: + images = latent_samples_to_images(context, context.partial_x_samples) + finally: + if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: + if not context.test_diffusers: + del context.partial_x_samples + context.partial_x_samples = None + + gc(context) + + if output_type == "base64": + output_format = opts.get("output_format", "jpeg") + output_quality = opts.get("output_quality", 75) + output_lossless = opts.get("output_lossless", False) + images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images] + + return images + + +def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"): + gc(context) + + if "nsfw_checker" in filters: + filters.remove("nsfw_checker") # handled by ED directly + + images = _filter_images(context, images, filters, filter_params) + + if input_type == "base64": + output_format = opts.get("output_format", "jpg") + output_quality = opts.get("output_quality", 75) + output_lossless = opts.get("output_lossless", False) + images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images] + + return images + + +def _filter_images(context, images, filters, filter_params={}): + from sdkit.filter import apply_filters + + filters = filters if isinstance(filters, list) else [filters] + filters = convert_ED_controlnet_filter_name(filters) + + for 
filter_name in filters: + params = filter_params.get(filter_name, {}) + + previous_state = before_filter(context, filter_name, params) + + try: + images = apply_filters(context, filter_name, images, **params) + finally: + after_filter(context, filter_name, params, previous_state) + + return images + + +def before_filter(context, filter_name, filter_params): + if filter_name == "codeformer": + from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use + + default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] + prev_realesrgan_path = None + + upscale_faces = filter_params.get("upscale_faces", False) + if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: + prev_realesrgan_path = context.model_paths.get("realesrgan") + context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") + load_model(context, "realesrgan") + + return prev_realesrgan_path + + +def after_filter(context, filter_name, filter_params, previous_state): + if filter_name == "codeformer": + prev_realesrgan_path = previous_state + if prev_realesrgan_path: + context.model_paths["realesrgan"] = prev_realesrgan_path + load_model(context, "realesrgan") + + +def get_url(): + pass + + +def stop_rendering(context): + context.stop_processing = True + + +def refresh_models(): + pass + + +def list_controlnet_filters(): + from sdkit.models.model_loader.controlnet_filters import filters as cn_filters + + return cn_filters + + +def make_step_callback(context, callback): + def on_step(x_samples, i, *args): + stream_image_progress = opts.get("stream_image_progress", False) + stream_image_progress_interval = opts.get("stream_image_progress_interval", 3) + + if context.test_diffusers: + context.partial_x_samples = (x_samples, args[0]) + else: + context.partial_x_samples = x_samples + + if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: + if context.test_diffusers: + images = 
diffusers_latent_samples_to_images(context, context.partial_x_samples) + else: + images = latent_samples_to_images(context, context.partial_x_samples) + else: + images = None + + if callback: + callback(images, i, *args) + + if context.stop_processing: + raise UserInitiatedStop("User requested that we stop processing") + + return on_step + + +def convert_ED_controlnet_filter_name(filter): + def cn(n): + if n.startswith("controlnet_"): + return n[len("controlnet_") :] + return n + + if isinstance(filter, list): + return [cn(f) for f in filter] + return cn(filter) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py new file mode 100644 index 00000000..f78d1164 --- /dev/null +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -0,0 +1,155 @@ +import os +import platform +import subprocess +import threading +from threading import local +import psutil + +from easydiffusion.app import ROOT_DIR, getConfig + +from . import impl +from .impl import ( + ping, + load_model, + unload_model, + set_options, + generate_images, + filter_images, + get_url, + stop_rendering, + refresh_models, + list_controlnet_filters, +) + + +ed_info = { + "name": "WebUI backend for Easy Diffusion", + "version": (1, 0, 0), + "type": "backend", +} + +BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui")) +SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") +WEBUI_DIR = os.path.join(BACKEND_DIR, "webui") + +backend_process = None + + +def install_backend(): + pass + + +def start_backend(): + config = getConfig() + backend_config = config.get("backend_config", {}) + + if not os.path.exists(BACKEND_DIR): + install_backend() + + impl.WEBUI_HOST = backend_config.get("host", "localhost") + impl.WEBUI_PORT = backend_config.get("port", "7860") + + env = dict(os.environ) + env.update(get_env()) + + def target(): + global backend_process + + cmd = "webui.bat" if platform.system() == "Windows" else "webui.sh" + print("starting", cmd, WEBUI_DIR) + 
backend_process = subprocess.Popen([cmd], shell=True, cwd=WEBUI_DIR, env=env) + + backend_thread = threading.Thread(target=target) + backend_thread.start() + + +def stop_backend(): + global backend_process + + if backend_process: + kill(backend_process.pid) + + backend_process = None + + +def uninstall_backend(): + pass + + +def create_context(): + context = local() + + # temp hack, throws an attribute not found error otherwise + context.device = "cuda:0" + context.half_precision = True + context.vram_usage_level = None + + context.models = {} + context.model_paths = {} + context.model_configs = {} + context.device_name = None + context.vram_optimizations = set() + context.vram_usage_level = "balanced" + context.test_diffusers = False + context.enable_codeformer = False + + return context + + +def get_env(): + dir = os.path.abspath(SYSTEM_DIR) + + if not os.path.exists(dir): + raise RuntimeError("The system folder is missing!") + + config = getConfig() + models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models")) + embeddings_dir = os.path.join(models_dir, "embeddings") + + env_entries = { + "PATH": [ + f"{dir}/git/bin", + f"{dir}/python", + f"{dir}/python/Library/bin", + f"{dir}/python/Scripts", + f"{dir}/python/Library/usr/bin", + ], + "PYTHONPATH": [ + f"{dir}/python", + f"{dir}/python/lib/site-packages", + f"{dir}/python/lib/python3.10/site-packages", + ], + "PYTHONHOME": [], + "PY_LIBS": [f"{dir}/python/Scripts/Lib", f"{dir}/python/Scripts/Lib/site-packages"], + "PY_PIP": [f"{dir}/python/Scripts"], + "PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"], + "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], + "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], + "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'], + "SKIP_VENV": ["1"], + "SD_WEBUI_RESTARTING": ["1"], + "PYTHON": [f"{dir}/python/python"], + "GIT": [f"{dir}/git/bin/git"], + } + + if platform.system() == "Windows": + 
env_entries["PYTHONNOUSERSITE"] = ["1"] + else: + env_entries["PYTHONNOUSERSITE"] = ["y"] + + env = {} + for key, paths in env_entries.items(): + paths = [p.replace("/", os.path.sep) for p in paths] + paths = os.pathsep.join(paths) + + env[key] = paths + + return env + + +# https://stackoverflow.com/a/25134985 +def kill(proc_pid): + process = psutil.Process(proc_pid) + for proc in process.children(recursive=True): + proc.kill() + process.kill() diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py new file mode 100644 index 00000000..c90970ec --- /dev/null +++ b/ui/easydiffusion/backends/webui/impl.py @@ -0,0 +1,639 @@ +import os +import requests +from requests.exceptions import ConnectTimeout +from typing import Union, List +from threading import local as Context +from threading import Thread +import uuid +import time +from copy import deepcopy + +from sdkit.utils import base64_str_to_img, img_to_base64_str + +WEBUI_HOST = "localhost" +WEBUI_PORT = "7860" + +DEFAULT_WEBUI_OPTIONS = {"show_progress_every_n_steps": 3, "show_progress_grid": True, "live_previews_enable": False} + + +webui_opts: dict = None + + +curr_models = { + "stable-diffusion": None, + "vae": None, +} + + +def set_options(context, **kwargs): + changed_opts = {} + + opts_mapping = { + "stream_image_progress": ("live_previews_enable", bool), + "stream_image_progress_interval": ("show_progress_every_n_steps", int), + "clip_skip": ("CLIP_stop_at_last_layers", int), + "clip_skip_sdxl": ("sdxl_clip_l_skip", bool), + "output_format": ("samples_format", str), + } + + for ed_key, webui_key in opts_mapping.items(): + webui_key, webui_type = webui_key + + if ed_key in kwargs and (webui_opts is None or webui_opts.get(webui_key, False) != webui_type(kwargs[ed_key])): + changed_opts[webui_key] = webui_type(kwargs[ed_key]) + + if changed_opts: + changed_opts["sd_model_checkpoint"] = curr_models["stable-diffusion"] + + print(f"Got options: {kwargs}. 
Sending options: {changed_opts}") + + try: + res = webui_post("/sdapi/v1/options", json=changed_opts) + if res.status_code != 200: + raise Exception(res.text) + + webui_opts.update(changed_opts) + except Exception as e: + print(f"Error setting options: {e}") + + +def ping(timeout=1): + "timeout (in seconds)" + + global webui_opts + + try: + webui_get("/internal/ping", timeout=timeout) + + if webui_opts is None: + try: + res = webui_post("/sdapi/v1/options", json=DEFAULT_WEBUI_OPTIONS) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + print(f"Error setting options: {e}") + + try: + res = webui_get("/sdapi/v1/options") + if res.status_code != 200: + raise Exception(res.text) + + webui_opts = res.json() + except Exception as e: + print(f"Error setting options: {e}") + + return True + except ConnectTimeout as e: + raise TimeoutError(e) + + +def load_model(context, model_type, **kwargs): + model_path = context.model_paths[model_type] + + if webui_opts is None: + print("Server not ready, can't set the model") + return + + if model_type == "stable-diffusion": + model_name = os.path.basename(model_path) + model_name = os.path.splitext(model_name)[0] + print(f"setting sd model: {model_name}") + if curr_models[model_type] != model_name: + try: + res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name}) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + raise RuntimeError( + f"The engine failed to set the required options. Please check the logs in the command line window for more details." 
+ ) + + curr_models[model_type] = model_name + elif model_type == "vae": + if curr_models[model_type] != model_path: + vae_model = [model_path] if model_path else [] + + opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model} + print("setting opts 2", opts) + + try: + res = webui_post("/sdapi/v1/options", json=opts) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + raise RuntimeError( + f"The engine failed to set the required options. Please check the logs in the command line window for more details." + ) + + curr_models[model_type] = model_path + + +def unload_model(context, model_type, **kwargs): + pass + + +def generate_images( + context: Context, + prompt: str = "", + negative_prompt: str = "", + seed: int = 42, + width: int = 512, + height: int = 512, + num_outputs: int = 1, + num_inference_steps: int = 25, + guidance_scale: float = 7.5, + init_image=None, + init_image_mask=None, + control_image=None, + control_alpha=1.0, + controlnet_filter=None, + prompt_strength: float = 0.8, + preserve_init_image_color_profile=False, + strict_mask_border=False, + sampler_name: str = "euler_a", + hypernetwork_strength: float = 0, + tiling=None, + lora_alpha: Union[float, List[float]] = 0, + sampler_params={}, + callback=None, + output_type="pil", +): + + task_id = str(uuid.uuid4()) + + sampler_name = convert_ED_sampler_names(sampler_name) + controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter) + + cmd = { + "force_task_id": task_id, + "prompt": prompt, + "negative_prompt": negative_prompt, + "sampler_name": sampler_name, + "scheduler": "simple", + "steps": num_inference_steps, + "seed": seed, + "cfg_scale": guidance_scale, + "batch_size": num_outputs, + "width": width, + "height": height, + } + + if init_image: + cmd["init_images"] = [init_image] + cmd["denoising_strength"] = prompt_strength + if init_image_mask: + cmd["mask"] = init_image_mask + cmd["include_init_images"] = 
True + cmd["inpainting_fill"] = 1 + cmd["initial_noise_multiplier"] = 1 + cmd["inpaint_full_res"] = 1 + + if context.model_paths.get("lora"): + lora_model = context.model_paths["lora"] + lora_model = lora_model if isinstance(lora_model, list) else [lora_model] + lora_alpha = lora_alpha if isinstance(lora_alpha, list) else [lora_alpha] + + for lora, alpha in zip(lora_model, lora_alpha): + lora = os.path.basename(lora) + lora = os.path.splitext(lora)[0] + cmd["prompt"] += f" " + + if controlnet_filter and control_image and context.model_paths.get("controlnet"): + controlnet_model = context.model_paths["controlnet"] + + model_hash = auto1111_hash(controlnet_model) + controlnet_model = os.path.basename(controlnet_model) + controlnet_model = os.path.splitext(controlnet_model)[0] + print(f"setting controlnet model: {controlnet_model}") + controlnet_model = f"{controlnet_model} [{model_hash}]" + + cmd["alwayson_scripts"] = { + "controlnet": { + "args": [ + { + "image": control_image, + "weight": control_alpha, + "module": controlnet_filter, + "model": controlnet_model, + "resize_mode": "Crop and Resize", + "threshold_a": 50, + "threshold_b": 130, + } + ] + } + } + + operation_to_apply = "img2img" if init_image else "txt2img" + + stream_image_progress = webui_opts.get("live_previews_enable", False) + + progress_thread = Thread( + target=image_progress_thread, args=(task_id, callback, stream_image_progress, num_outputs, num_inference_steps) + ) + progress_thread.start() + + print(f"task id: {task_id}") + print_request(operation_to_apply, cmd) + + res = webui_post(f"/sdapi/v1/{operation_to_apply}", json=cmd) + if res.status_code == 200: + res = res.json() + else: + raise Exception( + "The engine failed while generating this image. Please check the logs in the command-line window for more details." 
+ ) + + import json + + print(json.loads(res["info"])["infotexts"]) + + images = res["images"] + if output_type == "pil": + images = [base64_str_to_img(img) for img in images] + elif output_type == "base64": + images = [base64_buffer_to_base64_img(img) for img in images] + + return images + + +def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"): + """ + * context: Context + * images: str or PIL.Image or list of str/PIL.Image - image to filter. if a string is passed, it needs to be a base64-encoded image + * filters: filter_type (string) or list of strings + * filter_params: dict + + returns: [PIL.Image] - list of filtered images + """ + images = images if isinstance(images, list) else [images] + filters = filters if isinstance(filters, list) else [filters] + + if "nsfw_checker" in filters: + filters.remove("nsfw_checker") # handled by ED directly + + args = {} + controlnet_filters = [] + + print(filter_params) + + for filter_name in filters: + params = filter_params.get(filter_name, {}) + + if filter_name == "gfpgan": + args["gfpgan_visibility"] = 1 + + if filter_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + args["upscaler_1"] = params.get("upscaler", "RealESRGAN_x4plus") + args["upscaling_resize"] = params.get("scale", 4) + + if args["upscaler_1"] == "RealESRGAN_x4plus": + args["upscaler_1"] = "R-ESRGAN 4x+" + elif args["upscaler_1"] == "RealESRGAN_x4plus_anime_6B": + args["upscaler_1"] = "R-ESRGAN 4x+ Anime6B" + + if filter_name == "codeformer": + args["codeformer_visibility"] = 1 + args["codeformer_weight"] = params.get("codeformer_fidelity", 0.5) + + if filter_name.startswith("controlnet_"): + filter_name = convert_ED_controlnet_filter_name(filter_name) + controlnet_filters.append(filter_name) + + print(f"filtering {len(images)} images with {args}. 
{controlnet_filters=}") + + if len(filters) > len(controlnet_filters): + filtered_images = extra_batch_images(images, input_type=input_type, **args) + else: + filtered_images = images + + for filter_name in controlnet_filters: + filtered_images = controlnet_filter(filtered_images, module=filter_name, input_type=input_type) + + return filtered_images + + +def get_url(): + return f"//{WEBUI_HOST}:{WEBUI_PORT}/?__theme=dark" + + +def stop_rendering(context): + try: + res = webui_post("/sdapi/v1/interrupt") + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + print(f"Error interrupting webui: {e}") + + +def refresh_models(): + def make_refresh_call(type): + try: + webui_post(f"/sdapi/v1/refresh-{type}") + except: + pass + + try: + for type in ("checkpoints", "vae"): + t = Thread(target=make_refresh_call, args=(type,)) + t.start() + except Exception as e: + print(f"Error refreshing models: {e}") + + +def list_controlnet_filters(): + return [ + "openpose", + "openpose_face", + "openpose_faceonly", + "openpose_hand", + "openpose_full", + "animal_openpose", + "densepose_parula (black bg & blue torso)", + "densepose (pruple bg & purple torso)", + "dw_openpose_full", + "mediapipe_face", + "instant_id_face_keypoints", + "InsightFace+CLIP-H (IPAdapter)", + "InsightFace (InstantID)", + "canny", + "mlsd", + "scribble_hed", + "scribble_hedsafe", + "scribble_pidinet", + "scribble_pidsafe", + "scribble_xdog", + "softedge_hed", + "softedge_hedsafe", + "softedge_pidinet", + "softedge_pidsafe", + "softedge_teed", + "normal_bae", + "depth_midas", + "normal_midas", + "depth_zoe", + "depth_leres", + "depth_leres++", + "depth_anything_v2", + "depth_anything", + "depth_hand_refiner", + "depth_marigold", + "lineart_coarse", + "lineart_realistic", + "lineart_anime", + "lineart_standard (from white bg & black line)", + "lineart_anime_denoise", + "reference_adain", + "reference_only", + "reference_adain+attn", + "tile_colorfix", + "tile_resample", + 
"tile_colorfix+sharp", + "CLIP-ViT-H (IPAdapter)", + "CLIP-G (Revision)", + "CLIP-G (Revision ignore prompt)", + "CLIP-ViT-bigG (IPAdapter)", + "InsightFace+CLIP-H (IPAdapter)", + "inpaint_only", + "inpaint_only+lama", + "inpaint_global_harmonious", + "seg_ufade20k", + "seg_ofade20k", + "seg_anime_face", + "seg_ofcoco", + "shuffle", + "segment", + "invert (from white bg & black line)", + "threshold", + "t2ia_sketch_pidi", + "t2ia_color_grid", + "recolor_intensity", + "recolor_luminance", + "blur_gaussian", + ] + + +def controlnet_filter(images, module="none", processor_res=512, threshold_a=64, threshold_b=64, input_type="pil"): + if input_type == "pil": + images = [img_to_base64_str(x) for x in images] + + payload = { + "controlnet_module": module, + "controlnet_input_images": images, + "controlnet_processor_res": processor_res, + "controlnet_threshold_a": threshold_a, + "controlnet_threshold_b": threshold_b, + } + res = webui_post("/controlnet/detect", json=payload) + res = res.json() + filtered_images = res["images"] + + if input_type == "pil": + filtered_images = [base64_str_to_img(img) for img in filtered_images] + elif input_type == "base64": + filtered_images = [base64_buffer_to_base64_img(img) for img in filtered_images] + + return filtered_images + + +def image_progress_thread(task_id, callback, stream_image_progress, total_images, total_steps): + from PIL import Image + + last_preview_id = -1 + + EMPTY_IMAGE = Image.new("RGB", (1, 1)) + + while True: + res = webui_post( + f"/internal/progress", + json={"id_task": task_id, "live_preview": stream_image_progress, "id_live_preview": last_preview_id}, + ) + if res.status_code == 200: + res = res.json() + + last_preview_id = res["id_live_preview"] + + if res["progress"] is not None: + step_num = int(res["progress"] * total_steps) + + if res["live_preview"] is not None: + img = res["live_preview"] + img = base64_str_to_img(img) + images = [EMPTY_IMAGE] * total_images + images[0] = img + else: + images = None + + 
callback(images, step_num) + + if res["completed"] == True: + print("Complete!") + break + + time.sleep(0.5) + + +def webui_get(uri, *args, **kwargs): + url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}" + return requests.get(url, *args, **kwargs) + + +def webui_post(uri, *args, **kwargs): + url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}" + return requests.post(url, *args, **kwargs) + + +def print_request(operation_to_apply, args): + args = deepcopy(args) + if "init_images" in args: + args["init_images"] = ["img" for _ in args["init_images"]] + if "mask" in args: + args["mask"] = "mask_img" + + controlnet_args = args.get("alwayson_scripts", {}).get("controlnet", {}).get("args", []) + if controlnet_args: + controlnet_args[0]["image"] = "control_image" + + print(f"operation: {operation_to_apply}, args: {args}") + + +def auto1111_hash(file_path): + import hashlib + + with open(file_path, "rb") as f: + f.seek(0x100000) + b = f.read(0x10000) + return hashlib.sha256(b).hexdigest()[:8] + + +def extra_batch_images( + images, # list of PIL images + name_list=None, # list of image names + resize_mode=0, + show_extras_results=True, + gfpgan_visibility=0, + codeformer_visibility=0, + codeformer_weight=0, + upscaling_resize=2, + upscaling_resize_w=512, + upscaling_resize_h=512, + upscaling_crop=True, + upscaler_1="None", + upscaler_2="None", + extras_upscaler_2_visibility=0, + upscale_first=False, + use_async=False, + input_type="pil", +): + if name_list is not None: + if len(name_list) != len(images): + raise RuntimeError("len(images) != len(name_list)") + else: + name_list = [f"image{i + 1:05}" for i in range(len(images))] + + if input_type == "pil": + images = [img_to_base64_str(x) for x in images] + + image_list = [] + for name, image in zip(name_list, images): + image_list.append({"data": image, "name": name}) + + payload = { + "resize_mode": resize_mode, + "show_extras_results": show_extras_results, + "gfpgan_visibility": gfpgan_visibility, + "codeformer_visibility": 
codeformer_visibility, + "codeformer_weight": codeformer_weight, + "upscaling_resize": upscaling_resize, + "upscaling_resize_w": upscaling_resize_w, + "upscaling_resize_h": upscaling_resize_h, + "upscaling_crop": upscaling_crop, + "upscaler_1": upscaler_1, + "upscaler_2": upscaler_2, + "extras_upscaler_2_visibility": extras_upscaler_2_visibility, + "upscale_first": upscale_first, + "imageList": image_list, + } + + res = webui_post("/sdapi/v1/extra-batch-images", json=payload) + if res.status_code == 200: + res = res.json() + else: + raise Exception( + "The engine failed while filtering this image. Please check the logs in the command-line window for more details." + ) + + images = res["images"] + + if input_type == "pil": + images = [base64_str_to_img(img) for img in images] + elif input_type == "base64": + images = [base64_buffer_to_base64_img(img) for img in images] + + return images + + +def base64_buffer_to_base64_img(img): + output_format = webui_opts.get("samples_format", "jpeg") + mime_type = f"image/{output_format.lower()}" + return f"data:{mime_type};base64," + img + + +def convert_ED_sampler_names(sampler_name): + name_mapping = { + "dpmpp_2m": "DPM++ 2M", + "dpmpp_sde": "DPM++ SDE", + "dpmpp_2m_sde": "DPM++ 2M SDE", + "dpmpp_2m_sde_heun": "DPM++ 2M SDE Heun", + "dpmpp_2s_a": "DPM++ 2S a", + "dpmpp_3m_sde": "DPM++ 3M SDE", + "euler_a": "Euler a", + "euler": "Euler", + "lms": "LMS", + "heun": "Heun", + "dpm2": "DPM2", + "dpm2_a": "DPM2 a", + "dpm_fast": "DPM fast", + "dpm_adaptive": "DPM adaptive", + "restart": "Restart", + "heun_pp2": "HeunPP2", + "ipndm": "IPNDM", + "ipndm_v": "IPNDM_V", + "deis": "DEIS", + "ddim": "DDIM", + "ddim_cfgpp": "DDIM CFG++", + "plms": "PLMS", + "unipc": "UniPC", + "lcm": "LCM", + "ddpm": "DDPM", + "forge_flux_realistic": "[Forge] Flux Realistic", + "forge_flux_realistic_slow": "[Forge] Flux Realistic (Slow)", + # deprecated samplers in 3.5 + "dpm_solver_stability": None, + "unipc_snr": None, + "unipc_tu": None, + 
"unipc_snr_2": None, + "unipc_tu_2": None, + "unipc_tq": None, + } + return name_mapping.get(sampler_name) + + +def convert_ED_controlnet_filter_name(filter): + if filter is None: + return None + + def cn(n): + if n.startswith("controlnet_"): + return n[len("controlnet_") :] + return n + + mapping = { + "controlnet_scribble_hedsafe": None, + "controlnet_scribble_pidsafe": None, + "controlnet_softedge_pidsafe": "controlnet_softedge_pidisafe", + "controlnet_normal_bae": "controlnet_normalbae", + "controlnet_segment": None, + } + if isinstance(filter, list): + return [cn(mapping.get(f, f)) for f in filter] + return cn(mapping.get(filter, filter)) diff --git a/ui/easydiffusion/easydb/schemas.py b/ui/easydiffusion/easydb/schemas.py index 68bc04e2..c6b7337f 100644 --- a/ui/easydiffusion/easydb/schemas.py +++ b/ui/easydiffusion/easydb/schemas.py @@ -33,4 +33,3 @@ class Bucket(BucketBase): class Config: orm_mode = True - diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 26f05242..d821db41 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -8,7 +8,7 @@ from easydiffusion import app from easydiffusion.types import ModelsData from easydiffusion.utils import log from sdkit import Context -from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db +from sdkit.models import scan_model, download_model, get_model_info_from_db from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.utils import hash_file_quick from sdkit.models.model_loader.embeddings import get_embedding_token @@ -25,15 +25,15 @@ KNOWN_MODEL_TYPES = [ "controlnet", ] MODEL_EXTENSIONS = { - "stable-diffusion": [".ckpt", ".safetensors"], - "vae": [".vae.pt", ".ckpt", ".safetensors"], - "hypernetwork": [".pt", ".safetensors"], + "stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"], + "vae": [".vae.pt", ".ckpt", ".safetensors", ".sft"], + "hypernetwork": 
[".pt", ".safetensors", ".sft"], "gfpgan": [".pth"], "realesrgan": [".pth"], - "lora": [".ckpt", ".safetensors", ".pt"], + "lora": [".ckpt", ".safetensors", ".sft", ".pt"], "codeformer": [".pth"], - "embeddings": [".pt", ".bin", ".safetensors"], - "controlnet": [".pth", ".safetensors"], + "embeddings": [".pt", ".bin", ".safetensors", ".sft"], + "controlnet": [".pth", ".safetensors", ".sft"], } DEFAULT_MODELS = { "stable-diffusion": [ @@ -63,6 +63,7 @@ def init(): def load_default_models(context: Context): from easydiffusion import runtime + from easydiffusion.backend_manager import backend runtime.set_vram_optimizations(context) @@ -70,7 +71,7 @@ def load_default_models(context: Context): for model_type in MODELS_TO_LOAD_ON_START: context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False) try: - load_model( + backend.load_model( context, model_type, scan_model=context.model_paths[model_type] != None @@ -92,8 +93,10 @@ def load_default_models(context: Context): def unload_all(context: Context): + from easydiffusion.backend_manager import backend + for model_type in KNOWN_MODEL_TYPES: - unload_model(context, model_type) + backend.unload_model(context, model_type) if model_type in context.model_load_errors: del context.model_load_errors[model_type] @@ -154,6 +157,8 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []): + from easydiffusion.backend_manager import backend + models_to_reload = { model_type: path for model_type, path in models_data.model_paths.items() @@ -175,7 +180,7 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req - action_fn = unload_model if context.model_paths[model_type] is None else load_model + action_fn = 
backend.unload_model if context.model_paths[model_type] is None else backend.load_model extra_params = models_data.model_params.get(model_type, {}) try: action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already @@ -183,14 +188,23 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models del context.model_load_errors[model_type] except Exception as e: log.exception(e) - if action_fn == load_model: + if action_fn == backend.load_model: context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks def resolve_model_paths(models_data: ModelsData): model_paths = models_data.model_paths + skip_models = cn_filters + [ + "latent_upscaler", + "nsfw_checker", + "esrgan_4x", + "lanczos", + "nearest", + "scunet", + "swinir", + ] + for model_type in model_paths: - skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"] if model_type in skip_models: # doesn't use model paths continue if model_type == "codeformer" and model_paths[model_type]: @@ -320,6 +334,10 @@ def is_malicious_model(file_path): def getModels(scan_for_malicious: bool = True): + from easydiffusion.backend_manager import backend + + backend.refresh_models() + models = { "options": { "stable-diffusion": [], @@ -329,19 +347,19 @@ def getModels(scan_for_malicious: bool = True): "codeformer": [{"codeformer": "CodeFormer"}], "embeddings": [], "controlnet": [ - {"control_v11p_sd15_canny": "Canny (*)"}, - {"control_v11p_sd15_openpose": "OpenPose (*)"}, - {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, - {"control_v11f1p_sd15_depth": "Depth (*)"}, - {"control_v11p_sd15_scribble": "Scribble"}, - {"control_v11p_sd15_softedge": "Soft Edge"}, - {"control_v11p_sd15_inpaint": "Inpaint"}, - {"control_v11p_sd15_lineart": "Line Art"}, - {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, - {"control_v11p_sd15_mlsd": "Straight Lines"}, - {"control_v11p_sd15_seg": "Segment"}, - {"control_v11e_sd15_shuffle": 
"Shuffle"}, - {"control_v11f1e_sd15_tile": "Tile"}, + # {"control_v11p_sd15_canny": "Canny (*)"}, + # {"control_v11p_sd15_openpose": "OpenPose (*)"}, + # {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, + # {"control_v11f1p_sd15_depth": "Depth (*)"}, + # {"control_v11p_sd15_scribble": "Scribble"}, + # {"control_v11p_sd15_softedge": "Soft Edge"}, + # {"control_v11p_sd15_inpaint": "Inpaint"}, + # {"control_v11p_sd15_lineart": "Line Art"}, + # {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, + # {"control_v11p_sd15_mlsd": "Straight Lines"}, + # {"control_v11p_sd15_seg": "Segment"}, + # {"control_v11e_sd15_shuffle": "Shuffle"}, + # {"control_v11f1e_sd15_tile": "Tile"}, ], }, } @@ -378,6 +396,8 @@ def getModels(scan_for_malicious: bool = True): model_id = entry.name[: -len(matching_suffix)] if callable(nameFilter): model_id = nameFilter(model_id) + if model_id is None: + continue model_exists = False for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models @@ -416,7 +436,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="stable-diffusion") listModels(model_type="vae") listModels(model_type="hypernetwork") - listModels(model_type="gfpgan") + listModels(model_type="gfpgan", nameFilter=lambda x: (x if "gfpgan" in x.lower() else None)) listModels(model_type="lora") listModels(model_type="embeddings", nameFilter=get_embedding_token) listModels(model_type="controlnet") diff --git a/ui/easydiffusion/runtime.py b/ui/easydiffusion/runtime.py index 78d90f60..accced00 100644 --- a/ui/easydiffusion/runtime.py +++ b/ui/easydiffusion/runtime.py @@ -1,4 +1,5 @@ """ +(OUTDATED DOC) A runtime that runs on a specific device (in a thread). It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context. 
@@ -6,42 +7,35 @@ It can run various tasks like image generation, image filtering, model merge etc This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function. """ -from easydiffusion import device_manager -from easydiffusion.utils import log -from sdkit import Context -from sdkit.utils import get_device_usage - -context = Context() # thread-local -""" -runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc -""" +context = None def init(device): """ Initializes the fields that will be bound to this runtime's context, and sets the current torch device """ + + global context + + from easydiffusion import device_manager + from easydiffusion.backend_manager import backend + from easydiffusion.app import getConfig + + context = backend.create_context() + context.stop_processing = False context.temp_images = {} context.partial_x_samples = None context.model_load_errors = {} context.enable_codeformer = True - from easydiffusion import app - - app_config = app.getConfig() - context.test_diffusers = app_config.get("use_v3_engine", True) - - log.info("Device usage during initialization:") - get_device_usage(device, log_info=True, process_usage_only=False) - device_manager.device_init(context, device) -def set_vram_optimizations(context: Context): - from easydiffusion import app +def set_vram_optimizations(context): + from easydiffusion.app import getConfig - config = app.getConfig() + config = getConfig() vram_usage_level = config.get("vram_usage_level", "balanced") if vram_usage_level != context.vram_usage_level: diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index a251ede6..ca7dc98e 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -2,6 +2,7 @@ Notes: async endpoints always run on the main thread. Without they run on the thread pool. 
""" + import datetime import mimetypes import os @@ -20,6 +21,7 @@ from easydiffusion.types import ( OutputFormatData, SaveToDiskData, convert_legacy_render_req_to_new, + convert_legacy_controlnet_filter_name, ) from easydiffusion.utils import log from fastapi import FastAPI, HTTPException @@ -67,6 +69,7 @@ class SetAppConfigRequest(BaseModel, extra=Extra.allow): listen_to_network: bool = None listen_port: int = None use_v3_engine: bool = True + backend: str = "ed_diffusers" models_dir: str = None @@ -155,6 +158,12 @@ def init(): def shutdown_event(): # Signal render thread to close on shutdown task_manager.current_state_error = SystemExit("Application shutting down.") + @server_api.on_event("startup") + def start_event(): + from easydiffusion.app import open_browser + + open_browser() + # API implementations def set_app_config_internal(req: SetAppConfigRequest): @@ -176,7 +185,8 @@ def set_app_config_internal(req: SetAppConfigRequest): config["net"] = {} config["net"]["listen_port"] = int(req.listen_port) - config["use_v3_engine"] = req.use_v3_engine + config["use_v3_engine"] = req.backend == "ed_diffusers" + config["backend"] = req.backend config["models_dir"] = req.models_dir for property, property_value in req.dict().items(): @@ -216,6 +226,8 @@ def read_web_data_internal(key: str = None, **kwargs): return JSONResponse(config, headers=NOCACHE_HEADERS) elif key == "system_info": + from easydiffusion.backend_manager import backend + config = app.getConfig() output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) @@ -226,6 +238,7 @@ def read_web_data_internal(key: str = None, **kwargs): "default_output_dir": output_dir, "enforce_output_dir": ("force_save_path" in config), "enforce_output_metadata": ("force_save_metadata" in config), + "backend_url": backend.get_url(), } system_info["devices"]["config"] = config.get("render_devices", "auto") return JSONResponse(system_info, headers=NOCACHE_HEADERS) @@ -309,6 +322,15 @@ 
def filter_internal(req: dict): output_format: OutputFormatData = OutputFormatData.parse_obj(req) save_data: SaveToDiskData = SaveToDiskData.parse_obj(req) + filter_req.filter = convert_legacy_controlnet_filter_name(filter_req.filter) + + for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if models_data.model_paths.get(model_name): + if model_name not in filter_req.filter_params: + filter_req.filter_params[model_name] = {} + + filter_req.filter_params[model_name]["upscaler"] = models_data.model_paths[model_name] + # enqueue the task task = FilterTask(filter_req, task_data, models_data, output_format, save_data) return enqueue_task(task) diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 699b4494..5ad6420d 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -4,6 +4,7 @@ Notes: Use weak_thread_data to store all other data using weak keys. This will allow for garbage collection after the thread dies. """ + import json import traceback @@ -19,7 +20,6 @@ import torch from easydiffusion import device_manager from easydiffusion.tasks import Task from easydiffusion.utils import log -from sdkit.utils import gc THREAD_NAME_PREFIX = "" ERR_LOCK_FAILED = " failed to acquire lock within timeout." 
@@ -233,6 +233,7 @@ def thread_render(device): global current_state, current_state_error from easydiffusion import model_manager, runtime + from easydiffusion.backend_manager import backend try: runtime.init(device) @@ -244,8 +245,17 @@ def thread_render(device): } current_state = ServerStates.LoadingModel - model_manager.load_default_models(runtime.context) + while True: + try: + if backend.ping(timeout=1): + break + + time.sleep(1) + except TimeoutError: + time.sleep(1) + + model_manager.load_default_models(runtime.context) current_state = ServerStates.Online except Exception as e: log.error(traceback.format_exc()) @@ -291,7 +301,6 @@ def thread_render(device): task.buffer_queue.put(json.dumps(task.response)) log.error(traceback.format_exc()) finally: - gc(runtime.context) task.lock.release() keep_task_alive(task) diff --git a/ui/easydiffusion/tasks/filter_images.py b/ui/easydiffusion/tasks/filter_images.py index 7d3d1326..67572a8a 100644 --- a/ui/easydiffusion/tasks/filter_images.py +++ b/ui/easydiffusion/tasks/filter_images.py @@ -5,9 +5,7 @@ import time from numpy import base_repr -from sdkit.filter import apply_filters -from sdkit.models import load_model -from sdkit.utils import img_to_base64_str, get_image, log, save_images +from sdkit.utils import img_to_base64_str, log, save_images, base64_str_to_img from easydiffusion import model_manager, runtime from easydiffusion.types import ( @@ -19,6 +17,7 @@ from easydiffusion.types import ( TaskData, GenerateImageRequest, ) +from easydiffusion.utils import filter_nsfw from easydiffusion.utils.save_utils import format_folder_name from .task import Task @@ -47,7 +46,9 @@ class FilterTask(Task): # convert to multi-filter format, if necessary if isinstance(req.filter, str): - req.filter_params = {req.filter: req.filter_params} + if req.filter not in req.filter_params: + req.filter_params = {req.filter: req.filter_params} + req.filter = [req.filter] if not isinstance(req.image, list): @@ -57,6 +58,7 @@ class 
FilterTask(Task): "Runs the image filtering task on the assigned thread" from easydiffusion import app + from easydiffusion.backend_manager import backend context = runtime.context @@ -66,15 +68,24 @@ class FilterTask(Task): print_task_info(self.request, self.models_data, self.output_format, self.save_data) - if isinstance(self.request.image, list): - images = [get_image(img) for img in self.request.image] - else: - images = get_image(self.request.image) - - images = filter_images(context, images, self.request.filter, self.request.filter_params) + has_nsfw_filter = "nsfw_filter" in self.request.filter output_format = self.output_format + backend.set_options( + context, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + ) + + images = backend.filter_images( + context, self.request.image, self.request.filter, self.request.filter_params, input_type="base64" + ) + + if has_nsfw_filter: + images = filter_nsfw(images) + if self.save_data.save_to_disk_path is not None: app_config = app.getConfig() folder_format = app_config.get("folder_format", "$id") @@ -85,8 +96,9 @@ class FilterTask(Task): save_dir_path = os.path.join( self.save_data.save_to_disk_path, format_folder_name(folder_format, dummy_req, self.task_data) ) + images_pil = [base64_str_to_img(img) for img in images] save_images( - images, + images_pil, save_dir_path, file_name=img_id, output_format=output_format.output_format, @@ -94,13 +106,6 @@ class FilterTask(Task): output_lossless=output_format.output_lossless, ) - images = [ - img_to_base64_str( - img, output_format.output_format, output_format.output_quality, output_format.output_lossless - ) - for img in images - ] - res = FilterImageResponse(self.request, self.models_data, images=images) res = res.json() self.buffer_queue.put(json.dumps(res)) @@ -110,46 +115,6 @@ class FilterTask(Task): self.response = res -def filter_images(context, images, filters, 
filter_params={}): - filters = filters if isinstance(filters, list) else [filters] - - for filter_name in filters: - params = filter_params.get(filter_name, {}) - - previous_state = before_filter(context, filter_name, params) - - try: - images = apply_filters(context, filter_name, images, **params) - finally: - after_filter(context, filter_name, params, previous_state) - - return images - - -def before_filter(context, filter_name, filter_params): - if filter_name == "codeformer": - from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use - - default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] - prev_realesrgan_path = None - - upscale_faces = filter_params.get("upscale_faces", False) - if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: - prev_realesrgan_path = context.model_paths.get("realesrgan") - context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") - load_model(context, "realesrgan") - - return prev_realesrgan_path - - -def after_filter(context, filter_name, filter_params, previous_state): - if filter_name == "codeformer": - prev_realesrgan_path = previous_state - if prev_realesrgan_path: - context.model_paths["realesrgan"] = prev_realesrgan_path - load_model(context, "realesrgan") - - def print_task_info( req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData ): diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index 7d8fbc5e..2a46e6e6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -2,26 +2,23 @@ import json import pprint import queue import time +from PIL import Image from easydiffusion import model_manager, runtime from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData, SaveToDiskData from easydiffusion.types import Image as ResponseImage -from easydiffusion.types import 
GenerateImageResponse, RenderTaskData, UserInitiatedStop -from easydiffusion.utils import get_printable_request, log, save_images_to_disk -from sdkit.generate import generate_images +from easydiffusion.types import GenerateImageResponse, RenderTaskData +from easydiffusion.utils import get_printable_request, log, save_images_to_disk, filter_nsfw from sdkit.utils import ( - diffusers_latent_samples_to_images, - gc, img_to_base64_str, + base64_str_to_img, img_to_buffer, - latent_samples_to_images, resize_img, get_image, log, ) from .task import Task -from .filter_images import filter_images class RenderTask(Task): @@ -51,15 +48,13 @@ class RenderTask(Task): "Runs the image generation task on the assigned thread" from easydiffusion import task_manager, app + from easydiffusion.backend_manager import backend context = runtime.context config = app.getConfig() if config.get("block_nsfw", False): # override if set on the server self.task_data.block_nsfw = True - if "nsfw_checker" not in self.task_data.filters: - self.task_data.filters.append("nsfw_checker") - self.models_data.model_paths["nsfw_checker"] = "nsfw_checker" def step_callback(): task_manager.keep_task_alive(self) @@ -68,7 +63,7 @@ class RenderTask(Task): if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance( self.error, StopAsyncIteration ): - context.stop_processing = True + backend.stop_rendering(context) if isinstance(task_manager.current_state_error, StopAsyncIteration): self.error = task_manager.current_state_error task_manager.current_state_error = None @@ -78,11 +73,7 @@ class RenderTask(Task): model_manager.resolve_model_paths(self.models_data) models_to_force_reload = [] - if ( - runtime.set_vram_optimizations(context) - or self.has_param_changed(context, "clip_skip") - or self.trt_needs_reload(context) - ): + if runtime.set_vram_optimizations(context) or self.has_param_changed(context, "clip_skip"): models_to_force_reload.append("stable-diffusion") 
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload) @@ -99,10 +90,11 @@ class RenderTask(Task): self.buffer_queue, self.temp_images, step_callback, + self, ) def has_param_changed(self, context, param_name): - if not context.test_diffusers: + if not getattr(context, "test_diffusers", False): return False if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: return True @@ -111,29 +103,6 @@ class RenderTask(Task): new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) return model["params"].get(param_name) != new_val - def trt_needs_reload(self, context): - if not context.test_diffusers: - return False - if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: - return True - - model = context.models["stable-diffusion"] - - # curr_convert_to_trt = model["params"].get("convert_to_tensorrt") - new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False) - - pipe = model["default"] - is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr( - pipe.unet, "_allocate_trt_buffers_backup" - ) - if new_convert_to_trt and not is_trt_loaded: - return True - - curr_build_config = model["params"].get("trt_build_config") - new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {}) - - return new_convert_to_trt and curr_build_config != new_build_config - def make_images( context, @@ -145,12 +114,21 @@ def make_images( data_queue: queue.Queue, task_temp_images: list, step_callback, + task, ): - context.stop_processing = False print_task_info(req, task_data, models_data, output_format, save_data) images, seeds = make_images_internal( - context, req, task_data, models_data, output_format, save_data, data_queue, task_temp_images, step_callback + context, + req, + task_data, + models_data, + output_format, + 
save_data, + data_queue, + task_temp_images, + step_callback, + task, ) res = GenerateImageResponse( @@ -170,7 +148,9 @@ def print_task_info( output_format: OutputFormatData, save_data: SaveToDiskData, ): - req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[") + req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace( + "[", "\[" + ) task_str = pprint.pformat(task_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[") @@ -178,7 +158,7 @@ def print_task_info( log.info(f"request: {req_str}") log.info(f"task data: {task_str}") - # log.info(f"models data: {models_data}") + log.info(f"models data: {models_data}") log.info(f"output format: {output_format}") log.info(f"save data: {save_data}") @@ -193,26 +173,41 @@ def make_images_internal( data_queue: queue.Queue, task_temp_images: list, step_callback, + task, ): - images, user_stopped = generate_images_internal( + from easydiffusion.backend_manager import backend + + # prep the nsfw_filter + if task_data.block_nsfw: + filter_nsfw([Image.new("RGB", (1, 1))]) # hack - ensures that the model is available + + images = generate_images_internal( context, req, task_data, models_data, + output_format, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval, ) - - gc(context) + user_stopped = isinstance(task.error, StopAsyncIteration) filters, filter_params = task_data.filters, task_data.filter_params - filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images + if len(filters) > 0 and not user_stopped: + filtered_images = backend.filter_images(context, images, filters, filter_params, input_type="base64") + else: + filtered_images = images + + if task_data.block_nsfw: + filtered_images 
= filter_nsfw(filtered_images) if save_data.save_to_disk_path is not None: - save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data) + images_pil = [base64_str_to_img(img) for img in images] + filtered_images_pil = [base64_str_to_img(img) for img in filtered_images] + save_images_to_disk(images_pil, filtered_images_pil, req, task_data, models_data, output_format, save_data) seeds = [*range(req.seed, req.seed + len(images))] if task_data.show_only_filtered_image or filtered_images is images: @@ -226,97 +221,43 @@ def generate_images_internal( req: GenerateImageRequest, task_data: RenderTaskData, models_data: ModelsData, + output_format: OutputFormatData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool, stream_image_progress_interval: int, ): - context.temp_images.clear() + from easydiffusion.backend_manager import backend - callback = make_step_callback( + callback = make_step_callback(context, req, task_data, data_queue, task_temp_images, step_callback) + + req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8 + + if req.control_image and task_data.control_filter_to_apply: + req.controlnet_filter = task_data.control_filter_to_apply + + if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: + req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 + + backend.set_options( context, - req, - task_data, - data_queue, - task_temp_images, - step_callback, - stream_image_progress, - stream_image_progress_interval, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + vae_tiling=task_data.enable_vae_tiling, + stream_image_progress=stream_image_progress, + stream_image_progress_interval=stream_image_progress_interval, + clip_skip=2 if task_data.clip_skip else 1, ) - try: - if req.init_image is not 
None and not context.test_diffusers: - req.sampler_name = "ddim" + images = backend.generate_images(context, callback=callback, output_type="base64", **req.dict()) - req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8 - - if req.control_image and task_data.control_filter_to_apply: - req.control_image = get_image(req.control_image) - req.control_image = resize_img(req.control_image.convert("RGB"), req.width, req.height, clamp_to_8=True) - req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0] - - if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: - req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 - - if context.test_diffusers: - pipe = context.models["stable-diffusion"]["default"] - if hasattr(pipe.unet, "_allocate_trt_buffers_backup"): - setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup) - delattr(pipe.unet, "_allocate_trt_buffers_backup") - - if hasattr(pipe.unet, "_allocate_trt_buffers"): - convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False) - if convert_to_trt: - pipe.unet.forward = pipe.unet._trt_forward - # pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward - log.info(f"Setting unet.forward to TensorRT") - else: - log.info(f"Not using TensorRT for unet.forward") - pipe.unet.forward = pipe.unet._non_trt_forward - # pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward - setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers) - delattr(pipe.unet, "_allocate_trt_buffers") - - if task_data.enable_vae_tiling: - if hasattr(pipe, "enable_vae_tiling"): - pipe.enable_vae_tiling() - else: - if hasattr(pipe, "disable_vae_tiling"): - pipe.disable_vae_tiling() - - images = generate_images(context, callback=callback, **req.dict()) - user_stopped = False - except UserInitiatedStop: - images = [] - 
user_stopped = True - if context.partial_x_samples is not None: - if context.test_diffusers: - images = diffusers_latent_samples_to_images(context, context.partial_x_samples) - else: - images = latent_samples_to_images(context, context.partial_x_samples) - finally: - if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: - if not context.test_diffusers: - del context.partial_x_samples - context.partial_x_samples = None - - return images, user_stopped + return images def construct_response(images: list, seeds: list, output_format: OutputFormatData): - return [ - ResponseImage( - data=img_to_base64_str( - img, - output_format.output_format, - output_format.output_quality, - output_format.output_lossless, - ), - seed=seed, - ) - for img, seed in zip(images, seeds) - ] + return [ResponseImage(data=img, seed=seed) for img, seed in zip(images, seeds)] def make_step_callback( @@ -326,53 +267,44 @@ def make_step_callback( data_queue: queue.Queue, task_temp_images: list, step_callback, - stream_image_progress: bool, - stream_image_progress_interval: int, ): + from easydiffusion.backend_manager import backend + n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 - def update_temp_img(x_samples, task_temp_images: list): + def update_temp_img(images, task_temp_images: list): partial_images = [] - if context.test_diffusers: - images = diffusers_latent_samples_to_images(context, x_samples) - else: - images = latent_samples_to_images(context, x_samples) + if images is None: + return [] if task_data.block_nsfw: - images = filter_images(context, images, "nsfw_checker") + images = filter_nsfw(images, print_log=False) for i, img in enumerate(images): + img = img.convert("RGB") + img = resize_img(img, req.width, req.height) buf = img_to_buffer(img, output_format="JPEG") - context.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf 
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"}) del images return partial_images - def on_image_step(x_samples, i, *args): + def on_image_step(images, i, *args): nonlocal last_callback_time - if context.test_diffusers: - context.partial_x_samples = (x_samples, args[0]) - else: - context.partial_x_samples = x_samples - step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 last_callback_time = time.time() progress = {"step": i, "step_time": step_time, "total_steps": n_steps} - if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: - progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images) + if images is not None: + progress["output"] = update_temp_img(images, task_temp_images) data_queue.put(json.dumps(progress)) step_callback() - if context.stop_processing: - raise UserInitiatedStop("User requested that we stop processing") - return on_image_step diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index f16e9aa3..93bfe08f 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -19,9 +19,10 @@ class GenerateImageRequest(BaseModel): init_image_mask: Any = None control_image: Any = None control_alpha: Union[float, List[float]] = None + controlnet_filter: str = None prompt_strength: float = 0.8 - preserve_init_image_color_profile = False - strict_mask_border = False + preserve_init_image_color_profile: bool = False + strict_mask_border: bool = False sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" hypernetwork_strength: float = 0 @@ -100,7 +101,7 @@ class MergeRequest(BaseModel): model1: str = None ratio: float = None out_path: str = "mix" - use_fp16 = True + use_fp16: bool = True class Image: @@ -213,22 +214,19 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["controlnet"] = old_req.get("use_controlnet_model") model_paths["embeddings"] = 
old_req.get("use_embeddings_model") - model_paths["gfpgan"] = old_req.get("use_face_correction", "") - model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None + ## ensure that the model name is in the model path + for model_name in ("gfpgan", "codeformer"): + model_paths[model_name] = old_req.get("use_face_correction", "") + model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None - model_paths["codeformer"] = old_req.get("use_face_correction", "") - model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None + for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + model_paths[model_name] = old_req.get("use_upscale", "") + model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None - model_paths["realesrgan"] = old_req.get("use_upscale", "") - model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None - - model_paths["latent_upscaler"] = old_req.get("use_upscale", "") - model_paths["latent_upscaler"] = ( - model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None - ) if "control_filter_to_apply" in old_req: filter_model = old_req["control_filter_to_apply"] model_paths[filter_model] = filter_model + old_req["control_filter_to_apply"] = convert_legacy_controlnet_filter_name(old_req["control_filter_to_apply"]) if old_req.get("block_nsfw"): model_paths["nsfw_checker"] = "nsfw_checker" @@ -244,8 +242,12 @@ def convert_legacy_render_req_to_new(old_req: dict): } # move the filter params - if model_paths["realesrgan"]: - filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))} + for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if model_paths[model_name]: 
+ filter_params[model_name] = { + "upscaler": model_paths[model_name], + "scale": int(old_req.get("upscale_amount", 4)), + } if model_paths["latent_upscaler"]: filter_params["latent_upscaler"] = { "prompt": old_req["prompt"], @@ -264,14 +266,31 @@ def convert_legacy_render_req_to_new(old_req: dict): if old_req.get("block_nsfw"): filters.append("nsfw_checker") - if model_paths["codeformer"]: - filters.append("codeformer") - elif model_paths["gfpgan"]: - filters.append("gfpgan") + for model_name in ("gfpgan", "codeformer"): + if model_paths[model_name]: + filters.append(model_name) + break - if model_paths["realesrgan"]: - filters.append("realesrgan") - elif model_paths["latent_upscaler"]: - filters.append("latent_upscaler") + for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if model_paths[model_name]: + filters.append(model_name) + break return new_req + + +def convert_legacy_controlnet_filter_name(filter): + from easydiffusion.backend_manager import backend + + if filter is None: + return None + + controlnet_filter_names = backend.list_controlnet_filters() + + def apply(f): + return f"controlnet_{f}" if f in controlnet_filter_names else f + + if isinstance(filter, list): + return [apply(f) for f in filter] + + return apply(filter) diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py index a930725b..3be08e79 100644 --- a/ui/easydiffusion/utils/__init__.py +++ b/ui/easydiffusion/utils/__init__.py @@ -7,6 +7,8 @@ from .save_utils import ( save_images_to_disk, get_printable_request, ) +from .nsfw_checker import filter_nsfw + def sha256sum(filename): sha256 = hashlib.sha256() @@ -18,4 +20,3 @@ def sha256sum(filename): sha256.update(data) return sha256.hexdigest() - diff --git a/ui/easydiffusion/utils/nsfw_checker.py b/ui/easydiffusion/utils/nsfw_checker.py new file mode 100644 index 00000000..3790cacc --- /dev/null +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -0,0 +1,80 @@ +# 
possibly move this to sdkit in the future +import os + +# mirror of https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/model_quantized.onnx +NSFW_MODEL_URL = ( + "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/vit-base-nsfw-detector-quantized.onnx" +) +MODEL_HASH_QUICK = "220123559305b1b07b7a0894c3471e34dccd090d71cdf337dd8012f9e40d6c28" + +nsfw_check_model = None + + +def filter_nsfw(images, blur_radius: float = 75, print_log=True): + global nsfw_check_model + + from easydiffusion.app import MODELS_DIR + from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick + + import onnxruntime as ort + from PIL import ImageFilter + import numpy as np + + if nsfw_check_model is None: + model_dir = os.path.join(MODELS_DIR, "nsfw-checker") + model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") + + os.makedirs(model_dir, exist_ok=True) + + if not os.path.exists(model_path) or hash_file_quick(model_path) != MODEL_HASH_QUICK: + download_file(NSFW_MODEL_URL, model_path) + + nsfw_check_model = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + # Preprocess the input image + def preprocess_image(img): + img = img.convert("RGB") + + # config based on based on https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/preprocessor_config.json + # Resize the image + img = img.resize((384, 384)) + + # Normalize the image + img = np.array(img) / 255.0 # Scale pixel values to [0, 1] + mean = np.array([0.5, 0.5, 0.5]) + std = np.array([0.5, 0.5, 0.5]) + img = (img - mean) / std + + # Transpose to match input shape (batch_size, channels, height, width) + img = np.transpose(img, (2, 0, 1)).astype(np.float32) + + # Add batch dimension + img = np.expand_dims(img, axis=0) + + return img + + # Run inference + input_name = nsfw_check_model.get_inputs()[0].name + output_name = nsfw_check_model.get_outputs()[0].name + + if print_log: + log.info("Running NSFW 
checker (onnx)") + + results = [] + for img in images: + is_base64 = isinstance(img, str) + + input_img = base64_str_to_img(img) if is_base64 else img + + result = nsfw_check_model.run([output_name], {input_name: preprocess_image(input_img)}) + is_nsfw = [np.argmax(arr) == 1 for arr in result][0] + + if is_nsfw: + output_img = input_img.filter(ImageFilter.GaussianBlur(blur_radius)) + output_img = img_to_base64_str(output_img) if is_base64 else output_img + else: + output_img = img + + results.append(output_img) + + return results diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index e1b4d79a..216ec899 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -247,7 +247,7 @@ def get_printable_request( task_data_metadata.update(save_data.dict()) app_config = app.getConfig() - using_diffusers = app_config.get("use_v3_engine", True) + using_diffusers = app_config.get("backend", "ed_diffusers") in ("ed_diffusers", "webui") # Save the metadata in the order defined in TASK_TEXT_MAPPING metadata = {} diff --git a/ui/index.html b/ui/index.html index a50c995d..2fd11a5c 100644 --- a/ui/index.html +++ b/ui/index.html @@ -35,7 +35,10 @@

Easy Diffusion - v3.0.9 + + v3.5.0 + +

@@ -73,7 +76,7 @@
- +
@@ -83,7 +86,7 @@ Click to learn more about Negative Prompts (optional) - +
@@ -174,14 +177,14 @@ - + Click to learn more about Clip Skip - +
@@ -201,40 +204,92 @@ + + + + + + + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + + + + +
-
- +
@@ -248,27 +303,38 @@ - Click to learn more about samplers + Click to learn more about samplers
- + @@ -357,14 +423,14 @@
- + - +
- + - + @@ -405,7 +471,7 @@
- +
  • @@ -418,7 +484,13 @@ @@ -825,7 +897,8 @@

    This license of this software forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm,
    spread misinformation and target vulnerable groups. For the full list of restrictions please read the license.

    By using this software, you consent to the terms and conditions of the license.

    - + + diff --git a/ui/main.py b/ui/main.py index 0239c829..1c4056a3 100644 --- a/ui/main.py +++ b/ui/main.py @@ -9,6 +9,3 @@ server.init() model_manager.init() app.init_render_threads() bucket_manager.init() - -# start the browser ui -app.open_browser() diff --git a/ui/media/css/auto-save.css b/ui/media/css/auto-save.css index 94790740..021f06d5 100644 --- a/ui/media/css/auto-save.css +++ b/ui/media/css/auto-save.css @@ -79,6 +79,7 @@ } .parameters-table .fa-fire, -.parameters-table .fa-bolt { +.parameters-table .fa-bolt, +.parameters-table .fa-robot { color: #F7630C; } diff --git a/ui/media/css/main.css b/ui/media/css/main.css index 009c13d5..c5e862fe 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -36,6 +36,15 @@ code { transform: translateY(4px); cursor: pointer; } +#engine-logo { + font-size: 8pt; + padding-left: 10pt; + color: var(--small-label-color); +} +#engine-logo a { + text-decoration: none; + /* color: var(--small-label-color); */ +} #prompt { width: 100%; height: 65pt; @@ -541,7 +550,7 @@ div.img-preview img { position: relative; background: var(--background-color4); display: flex; - padding: 12px 0 0; + padding: 6px 0 0; } .tab .icon { padding-right: 4pt; @@ -657,6 +666,10 @@ div.img-preview img { display: block; } +.gated-feature { + display: none; +} + .display-settings { float: right; position: relative; diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 1cbd3288..2b90a6d3 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -981,7 +981,20 @@ function onRedoFilter(req, img, e, tools) { function onUpscaleClick(req, img, e, tools) { let path = upscaleModelField.value let scale = parseInt(upscaleAmountField.value) - let filterName = path.toLowerCase().includes("realesrgan") ? 
"realesrgan" : "latent_upscaler" + + let filterName = null + const FILTERS = ["realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"] + for (let idx in FILTERS) { + let f = FILTERS[idx] + if (path.toLowerCase().includes(f)) { + filterName = f + break + } + } + + if (!filterName) { + return + } let statusText = "Upscaling by " + scale + "x using " + filterName applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools) } diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 97b7a96d..a597b281 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -249,14 +249,19 @@ var PARAMETERS = [ default: false, }, { - id: "use_v3_engine", - type: ParameterType.checkbox, - label: "Use the new v3 engine (diffusers)", + id: "backend", + type: ParameterType.select, + label: "Engine to use", note: - "Use our new v3 engine, with additional features like LoRA, ControlNet, SDXL, Embeddings, Tiling and lots more! Please press Save, then restart the program after changing this.", - icon: "fa-bolt", - default: true, + "Use our new v3.5 engine (Forge), with additional features like Flux, SD3, Lycoris and lots more! 
Please press Save, then restart the program after changing this.", + icon: "fa-robot", saveInAppConfig: true, + default: "ed_diffusers", + options: [ + { value: "webui", label: "v3.5 (latest)" }, + { value: "ed_diffusers", label: "v3.0" }, + { value: "ed_classic", label: "v2.0" }, + ], }, { id: "cloudflare", @@ -432,6 +437,7 @@ let useBetaChannelField = document.querySelector("#use_beta_channel") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") let testDiffusers = document.querySelector("#use_v3_engine") +let backendEngine = document.querySelector("#backend") let profileNameField = document.querySelector("#profileName") let modelsDirField = document.querySelector("#models_dir") @@ -454,6 +460,23 @@ async function changeAppConfig(configDelta) { } } +function getDefaultDisplay(element) { + const tag = element.tagName.toLowerCase(); + const defaultDisplays = { + div: 'block', + span: 'inline', + p: 'block', + tr: 'table-row', + table: 'table', + li: 'list-item', + ul: 'block', + ol: 'block', + button: 'inline', + // Add more if needed + }; + return defaultDisplays[tag] || 'block'; // Default to 'block' if not listed +} + async function getAppConfig() { try { let res = await fetch("/get/app_config") @@ -478,14 +501,16 @@ async function getAppConfig() { modelsDirField.value = config.models_dir let testDiffusersEnabled = true - if (config.use_v3_engine === false) { + if (config.backend === "ed_classic") { testDiffusersEnabled = false } testDiffusers.checked = testDiffusersEnabled + backendEngine.value = config.backend document.querySelector("#test_diffusers").checked = testDiffusers.checked // don't break plugins + document.querySelector("#use_v3_engine").checked = testDiffusers.checked // don't break plugins if (config.config_on_startup) { - if (config.config_on_startup?.use_v3_engine) { + if (config.config_on_startup?.backend !== "ed_classic") { 
document.body.classList.add("diffusers-enabled-on-startup") document.body.classList.remove("diffusers-disabled-on-startup") } else { @@ -494,37 +519,27 @@ async function getAppConfig() { } } - if (!testDiffusersEnabled) { - document.querySelector("#lora_model_container").style.display = "none" - document.querySelector("#tiling_container").style.display = "none" - document.querySelector("#controlnet_model_container").style.display = "none" - document.querySelector("#hypernetwork_model_container").style.display = "" - document.querySelector("#hypernetwork_strength_container").style.display = "" - document.querySelector("#negative-embeddings-button").style.display = "none" - - document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => { - option.style.display = "none" - }) + if (config.backend === "ed_classic") { IMAGE_STEP_SIZE = 64 - customWidthField.step = IMAGE_STEP_SIZE - customHeightField.step = IMAGE_STEP_SIZE } else { - document.querySelector("#lora_model_container").style.display = "" - document.querySelector("#tiling_container").style.display = "" - document.querySelector("#controlnet_model_container").style.display = "" - document.querySelector("#hypernetwork_model_container").style.display = "none" - document.querySelector("#hypernetwork_strength_container").style.display = "none" - - document.querySelectorAll("#sampler_name option.k_diffusion-only").forEach((option) => { - option.style.display = "none" - }) - document.querySelector("#clip_skip_config").classList.remove("displayNone") - document.querySelector("#embeddings-button").classList.remove("displayNone") IMAGE_STEP_SIZE = 8 - customWidthField.step = IMAGE_STEP_SIZE - customHeightField.step = IMAGE_STEP_SIZE } + customWidthField.step = IMAGE_STEP_SIZE + customHeightField.step = IMAGE_STEP_SIZE + + const currentBackendKey = "backend_" + config.backend + + document.querySelectorAll('.gated-feature').forEach((element) => { + const featureKeys = 
element.getAttribute('data-feature-keys').split(' ') + + if (featureKeys.includes(currentBackendKey)) { + element.style.display = getDefaultDisplay(element) + } else { + element.style.display = 'none' + } + }); + if (config.force_save_metadata) { metadataOutputFormatField.value = config.force_save_metadata } @@ -749,6 +764,11 @@ async function getSystemInfo() { metadataOutputFormatField.disabled = !saveToDiskField.checked } setDiskPath(res["default_output_dir"], force) + + // backend info + if (res["backend_url"]) { + document.querySelector("#backend-url").setAttribute("href", res["backend_url"]) + } } catch (e) { console.log("error fetching devices", e) } From 754a5f5e52511365aacd0e04c88e94a956fe30f7 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Oct 2024 13:55:35 +0530 Subject: [PATCH 08/64] Case-insensitive model directories --- ui/easydiffusion/backends/webui/__init__.py | 27 ++++++++++++- ui/easydiffusion/model_manager.py | 42 ++++++++++++++++++--- ui/easydiffusion/server.py | 8 ++-- ui/easydiffusion/utils/nsfw_checker.py | 4 +- ui/media/js/parameters.js | 2 +- 5 files changed, 67 insertions(+), 16 deletions(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index f78d1164..abeece38 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -6,6 +6,7 @@ from threading import local import psutil from easydiffusion.app import ROOT_DIR, getConfig +from easydiffusion.model_manager import get_model_dir from . 
import impl from .impl import ( @@ -32,6 +33,18 @@ BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui")) SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") WEBUI_DIR = os.path.join(BACKEND_DIR, "webui") +MODELS_TO_OVERRIDE = { + "stable-diffusion": "--ckpt-dir", + "vae": "--vae-dir", + "hypernetwork": "--hypernetwork-dir", + "gfpgan": "--gfpgan-models-path", + "realesrgan": "--realesrgan-models-path", + "lora": "--lora-dir", + "codeformer": "--codeformer-models-path", + "embeddings": "--embeddings-dir", + "controlnet": "--controlnet-dir", +} + backend_process = None @@ -104,7 +117,8 @@ def get_env(): config = getConfig() models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models")) - embeddings_dir = os.path.join(models_dir, "embeddings") + + model_path_args = get_model_path_args() env_entries = { "PATH": [ @@ -125,7 +139,7 @@ def get_env(): "PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"], "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], - "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'], + "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args}'], "SKIP_VENV": ["1"], "SD_WEBUI_RESTARTING": ["1"], "PYTHON": [f"{dir}/python/python"], @@ -153,3 +167,12 @@ def kill(proc_pid): for proc in process.children(recursive=True): proc.kill() process.kill() + + +def get_model_path_args(): + args = [] + for model_type, flag in MODELS_TO_OVERRIDE.items(): + model_dir = get_model_dir(model_type) + args.append(f'{flag} "{model_dir}"') + + return " ".join(args) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index d821db41..2a20fce9 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -122,7 +122,7 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() - model_dir = 
os.path.join(app.MODELS_DIR, model_type) + model_dir = get_model_dir(model_type) if not model_name: # When None try user configured model. # config = getConfig() if "model" in config and model_type in config["model"]: @@ -239,7 +239,8 @@ def download_default_models_if_necessary(): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): - model_path = os.path.join(app.MODELS_DIR, model_type, file_name) + model_dir = get_model_dir(model_type) + model_path = os.path.join(model_dir, file_name) expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] other_models_exist = any_model_exists(model_type) and skip_if_others_exist @@ -259,13 +260,15 @@ def migrate_legacy_model_location(): file_name = model["file_name"] legacy_path = os.path.join(app.SD_DIR, file_name) if os.path.exists(legacy_path): - shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name)) + model_dir = get_model_dir(model_type) + shutil.move(legacy_path, os.path.join(model_dir, file_name)) def any_model_exists(model_type: str) -> bool: extensions = MODEL_EXTENSIONS.get(model_type, []) + model_dir = get_model_dir(model_type) for ext in extensions: - if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)): + if any(glob(f"{model_dir}/**/*{ext}", recursive=True)): return True return False @@ -273,7 +276,7 @@ def any_model_exists(model_type: str) -> bool: def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: - model_dir_path = os.path.join(app.MODELS_DIR, model_type) + model_dir_path = get_model_dir(model_type) try: os.makedirs(model_dir_path, exist_ok=True) @@ -418,7 +421,7 @@ def getModels(scan_for_malicious: bool = True): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) - models_dir = os.path.join(app.MODELS_DIR, model_type) + models_dir = get_model_dir(model_type) if not os.path.exists(models_dir): os.makedirs(models_dir) @@ -445,3 +448,30 @@ def 
getModels(scan_for_malicious: bool = True): log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") return models + + +def get_model_dir(model_type: str, base_dir=None): + "Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)" + + if base_dir is None: + base_dir = app.MODELS_DIR + + for dir in os.listdir(base_dir): + if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)): + return os.path.join(base_dir, dir) + + return os.path.join(base_dir, model_type) + + +# patch sdkit +def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type): + "Patched version that works with case-insensitive model sub-dirs" + + download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir + download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir + return os.path.abspath(download_base_dir) + + +from sdkit.models import model_downloader + +model_downloader.get_actual_base_dir = __patched__get_actual_base_dir diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index ca7dc98e..f1b85764 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -364,15 +364,13 @@ def model_merge_internal(req: dict): mergeReq: MergeRequest = MergeRequest.parse_obj(req) + sd_model_dir = model_manager.get_model_dir("stable-diffusion") + merge_models( model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), mergeReq.ratio, - os.path.join( - app.MODELS_DIR, - "stable-diffusion", - filename_regex.sub("_", mergeReq.out_path), - ), + os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)), mergeReq.use_fp16, ) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) diff --git a/ui/easydiffusion/utils/nsfw_checker.py 
b/ui/easydiffusion/utils/nsfw_checker.py index 3790cacc..9e371a37 100644 --- a/ui/easydiffusion/utils/nsfw_checker.py +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -13,7 +13,7 @@ nsfw_check_model = None def filter_nsfw(images, blur_radius: float = 75, print_log=True): global nsfw_check_model - from easydiffusion.app import MODELS_DIR + from easydiffusion.model_manager import get_model_dir from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick import onnxruntime as ort @@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True): import numpy as np if nsfw_check_model is None: - model_dir = os.path.join(MODELS_DIR, "nsfw-checker") + model_dir = get_model_dir("nsfw-checker") model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") os.makedirs(model_dir, exist_ok=True) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index a597b281..c8624f40 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -102,7 +102,7 @@ var PARAMETERS = [ type: ParameterType.custom, icon: "fa-folder-tree", label: "Models Folder", - note: "Path to the 'models' folder. Please save and refresh the page after changing this.", + note: "Path to the 'models' folder. 
Please save and restart Easy Diffusion after changing this.", saveInAppConfig: true, render: (parameter) => { return `` From f51ab909ff58ec74c283fc476adddf199f24bc3d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Oct 2024 16:25:24 +0530 Subject: [PATCH 09/64] Reset VAE upon restart --- ui/easydiffusion/backends/webui/impl.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py index c90970ec..393f8ea9 100644 --- a/ui/easydiffusion/backends/webui/impl.py +++ b/ui/easydiffusion/backends/webui/impl.py @@ -13,7 +13,12 @@ from sdkit.utils import base64_str_to_img, img_to_base64_str WEBUI_HOST = "localhost" WEBUI_PORT = "7860" -DEFAULT_WEBUI_OPTIONS = {"show_progress_every_n_steps": 3, "show_progress_grid": True, "live_previews_enable": False} +DEFAULT_WEBUI_OPTIONS = { + "show_progress_every_n_steps": 3, + "show_progress_grid": True, + "live_previews_enable": False, + "forge_additional_modules": [], +} webui_opts: dict = None @@ -449,6 +454,8 @@ def image_progress_thread(task_id, callback, stream_image_progress, total_images ) if res.status_code == 200: res = res.json() + else: + raise RuntimeError(f"Unexpected progress response. Status code: {res.status_code}. 
Res: {res.text}") last_preview_id = res["id_live_preview"] From 4e3a5cb6d9caa3f3c25b3d03edd36f073e5ed28a Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Oct 2024 16:25:56 +0530 Subject: [PATCH 10/64] Automatically restart webui if it stops/crashes --- ui/easydiffusion/backends/webui/__init__.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index abeece38..94d9ff03 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -69,8 +69,13 @@ def start_backend(): global backend_process cmd = "webui.bat" if platform.system() == "Windows" else "webui.sh" - print("starting", cmd, WEBUI_DIR) - backend_process = subprocess.Popen([cmd], shell=True, cwd=WEBUI_DIR, env=env) + + while True: + print("starting", cmd, WEBUI_DIR) + backend_process = subprocess.Popen([cmd], shell=True, cwd=WEBUI_DIR, env=env) + backend_process.wait() + + stop_backend() backend_thread = threading.Thread(target=target) backend_thread.start() @@ -80,7 +85,10 @@ def stop_backend(): global backend_process if backend_process: - kill(backend_process.pid) + try: + kill(backend_process.pid) + except: + pass backend_process = None From 391e12e20d1f04c7159785af9ba807d5e6d11233 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Oct 2024 16:41:26 +0530 Subject: [PATCH 11/64] Require onnxruntime for nsfw checking --- scripts/check_modules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 9a067e68..b3ecfd07 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -34,6 +34,7 @@ modules_to_check = { "sqlalchemy": "2.0.19", "python-multipart": "0.0.6", # "xformers": "0.0.16", + "onnxruntime": "1.19.2", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"] From 6ea7dd36dad66dba19929916d68862a1f3954a9f Mon Sep 17 00:00:00 2001 From: 
cmdr2 Date: Wed, 2 Oct 2024 15:30:29 +0530 Subject: [PATCH 12/64] Fix a bug where the config file wasn't actually read on linux/mac --- scripts/check_modules.py | 2 +- scripts/get_config.py | 1 + ui/easydiffusion/app.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index b3ecfd07..0d58d43d 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -298,7 +298,7 @@ Thanks!""" def get_config(): config_directory = os.path.dirname(__file__) # this will be "scripts" - config_yaml = os.path.join(config_directory, "..", "config.yaml") + config_yaml = os.path.abspath(os.path.join(config_directory, "..", "config.yaml")) config_json = os.path.join(config_directory, "config.json") config = None diff --git a/scripts/get_config.py b/scripts/get_config.py index 0bcc90a1..58f0e9ff 100644 --- a/scripts/get_config.py +++ b/scripts/get_config.py @@ -6,6 +6,7 @@ import shutil # The config file is in the same directory as this script config_directory = os.path.dirname(__file__) config_yaml = os.path.join(config_directory, "..", "config.yaml") +config_yaml = os.path.abspath(config_yaml) config_json = os.path.join(config_directory, "config.json") parser = argparse.ArgumentParser(description='Get values from config file') diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index ec856991..c570c773 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -119,6 +119,7 @@ def init_render_threads(): def getConfig(default_val=APP_CONFIG_DEFAULTS): config_yaml_path = os.path.join(CONFIG_DIR, "..", "config.yaml") + config_yaml_path = os.path.abspath(config_yaml_path) # migrate the old config yaml location config_legacy_yaml = os.path.join(CONFIG_DIR, "config.yaml") From 6f4e2017f4f4f665b559242a138dcf8cb21a40bb Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 7 Oct 2024 11:28:28 +0530 Subject: [PATCH 13/64] Install Forge automatically by creating a conda environment and cloing the forge 
repo --- ui/easydiffusion/app.py | 33 +++- ui/easydiffusion/backends/ed_classic.py | 1 + ui/easydiffusion/backends/ed_diffusers.py | 1 + ui/easydiffusion/backends/sdkit_common.py | 4 + ui/easydiffusion/backends/webui/__init__.py | 194 ++++++++++++++++++-- ui/easydiffusion/task_manager.py | 3 +- 6 files changed, 207 insertions(+), 29 deletions(-) diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index c570c773..6b1d2a09 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -331,16 +331,31 @@ def open_browser(): webbrowser.open(f"http://localhost:{port}") - Console().print( - Panel( - "\n" - + "[white]Easy Diffusion is ready to serve requests.\n\n" - + "A new browser tab should have been opened by now.\n" - + f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n", - title="Easy Diffusion is ready", - style="bold yellow on blue", + from easydiffusion.backend_manager import backend + + if backend.is_installed(): + Console().print( + Panel( + "\n" + + "[white]Easy Diffusion is ready to serve requests.\n\n" + + "A new browser tab should have been opened by now.\n" + + f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n", + title="Easy Diffusion is ready", + style="bold yellow on blue", + ) + ) + else: + backend_name = config["backend"] + Console().print( + Panel( + "\n" + + f"[white]Backend: {backend_name} is still installing..\n\n" + + "Please wait until it finishes installing before making images.\n" + + f"The UI will turn green in the top-right corner once it is ready (in [bold yellow underline]http://localhost:{port}[/]).\n", + title=f"Backend engine is installing", + style="bold yellow on blue", + ) ) - ) def fail_and_die(fail_type: str, data: str): diff --git a/ui/easydiffusion/backends/ed_classic.py b/ui/easydiffusion/backends/ed_classic.py index c9cf745e..07077b6c 100644 --- a/ui/easydiffusion/backends/ed_classic.py +++ 
b/ui/easydiffusion/backends/ed_classic.py @@ -3,6 +3,7 @@ from sdkit_common import ( stop_backend, install_backend, uninstall_backend, + is_installed, create_sdkit_context, ping, load_model, diff --git a/ui/easydiffusion/backends/ed_diffusers.py b/ui/easydiffusion/backends/ed_diffusers.py index c905652d..b5cf8744 100644 --- a/ui/easydiffusion/backends/ed_diffusers.py +++ b/ui/easydiffusion/backends/ed_diffusers.py @@ -3,6 +3,7 @@ from sdkit_common import ( stop_backend, install_backend, uninstall_backend, + is_installed, create_sdkit_context, ping, load_model, diff --git a/ui/easydiffusion/backends/sdkit_common.py b/ui/easydiffusion/backends/sdkit_common.py index d7a49c3e..4c7eba59 100644 --- a/ui/easydiffusion/backends/sdkit_common.py +++ b/ui/easydiffusion/backends/sdkit_common.py @@ -28,6 +28,10 @@ def uninstall_backend(): pass +def is_installed(): + return True + + def create_sdkit_context(use_diffusers): c = Context() c.test_diffusers = use_diffusers diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index 94d9ff03..95cd7883 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -4,6 +4,7 @@ import subprocess import threading from threading import local import psutil +import shutil from easydiffusion.app import ROOT_DIR, getConfig from easydiffusion.model_manager import get_model_dir @@ -29,10 +30,15 @@ ed_info = { "type": "backend", } +WEBUI_REPO = "https://github.com/lllyasviel/stable-diffusion-webui-forge.git" +WEBUI_COMMIT = "f4d5e8cac16a42fa939e78a0956b4c30e2b47bb5" + BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui")) SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") WEBUI_DIR = os.path.join(BACKEND_DIR, "webui") +OS_NAME = platform.system() + MODELS_TO_OVERRIDE = { "stable-diffusion": "--ckpt-dir", "vae": "--vae-dir", @@ -46,10 +52,42 @@ MODELS_TO_OVERRIDE = { } backend_process = None +conda = "conda" + + +def locate_conda(): + global 
conda + + which = "where" if OS_NAME == "Windows" else "which" + conda = subprocess.getoutput(f"{which} conda") + conda = conda.split("\n") + conda = conda[0].strip() + print("conda: ", conda) + + +locate_conda() def install_backend(): - pass + print("Installing the WebUI backend..") + + # create the conda env + run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR) + + # install python 3.10 and git in the conda env + run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR) + + # print info + run_in_conda(["git", "--version"], cwd=ROOT_DIR) + run_in_conda(["python", "--version"], cwd=ROOT_DIR) + + # clone webui + run_in_conda(["git", "clone", WEBUI_REPO, WEBUI_DIR], cwd=ROOT_DIR) + + # install cpu-only torch if the PC doesn't have a graphics card (for Windows and Linux). + # this avoids WebUI installing a CUDA version and trying to activate it + if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card(): + run_in_conda(["python", "-m", "pip", "install", "torch", "torchvision"], cwd=WEBUI_DIR) def start_backend(): @@ -59,6 +97,18 @@ def start_backend(): if not os.path.exists(BACKEND_DIR): install_backend() + if backend_config.get("auto_update", True): + run_in_conda(["git", "add", "-A", "."], cwd=WEBUI_DIR) + run_in_conda(["git", "stash"], cwd=WEBUI_DIR) + run_in_conda(["git", "reset", "--hard"], cwd=WEBUI_DIR) + run_in_conda(["git", "fetch"], cwd=WEBUI_DIR) + run_in_conda(["git", "-c", "advice.detachedHead=false", "checkout", WEBUI_COMMIT], cwd=WEBUI_DIR) + + # hack to prevent webui-macos-env.sh from overwriting the COMMANDLINE_ARGS env variable + mac_webui_file = os.path.join(WEBUI_DIR, "webui-macos-env.sh") + if os.path.exists(mac_webui_file): + os.remove(mac_webui_file) + impl.WEBUI_HOST = backend_config.get("host", "localhost") impl.WEBUI_PORT = backend_config.get("port", "7860") @@ -68,11 +118,11 @@ def start_backend(): def target(): global backend_process - cmd = "webui.bat" if 
platform.system() == "Windows" else "webui.sh" + cmd = "webui.bat" if OS_NAME == "Windows" else "./webui.sh" while True: print("starting", cmd, WEBUI_DIR) - backend_process = subprocess.Popen([cmd], shell=True, cwd=WEBUI_DIR, env=env) + backend_process = run_in_conda([cmd], cwd=WEBUI_DIR, env=env, wait=False) backend_process.wait() stop_backend() @@ -94,7 +144,51 @@ def stop_backend(): def uninstall_backend(): - pass + shutil.rmtree(BACKEND_DIR) + + +def is_installed(): + if not os.path.exists(BACKEND_DIR) or not os.path.exists(SYSTEM_DIR) or not os.path.exists(WEBUI_DIR): + return True + + env = dict(os.environ) + env.update(get_env()) + + try: + out = check_output_in_conda(["python", "-m", "pip", "show", "torch"], env=env) + return "Version" in out.decode() + except subprocess.CalledProcessError: + pass + + return False + + +def run(cmds: list, cwd=None, env=None, stream_output=True, wait=True): + p = subprocess.Popen(cmds, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + if stream_output: + while True: + output = p.stdout.readline() + output = output.decode() + if output == "" and p.poll() is not None: + break + if output: + print(output, end="") + + if wait: + p.wait() + + return p + + +def run_in_conda(cmds: list, *args, **kwargs): + cmds = [conda, "run", "--no-capture-output", "--prefix", SYSTEM_DIR] + cmds + return run(cmds, *args, **kwargs) + + +def check_output_in_conda(cmds: list, cwd=None, env=None): + cmds = [conda, "run", "--no-capture-output", "--prefix", SYSTEM_DIR] + cmds + return subprocess.check_output(cmds, cwd=cwd, env=env, stderr=subprocess.PIPE) def create_context(): @@ -130,34 +224,62 @@ def get_env(): env_entries = { "PATH": [ - f"{dir}/git/bin", - f"{dir}/python", - f"{dir}/python/Library/bin", - f"{dir}/python/Scripts", - f"{dir}/python/Library/usr/bin", + f"{dir}", + f"{dir}/bin", + f"{dir}/Library/bin", + f"{dir}/Scripts", + f"{dir}/usr/bin", ], "PYTHONPATH": [ - f"{dir}/python", - 
f"{dir}/python/lib/site-packages", - f"{dir}/python/lib/python3.10/site-packages", + f"{dir}", + f"{dir}/lib/site-packages", + f"{dir}/lib/python3.10/site-packages", ], "PYTHONHOME": [], - "PY_LIBS": [f"{dir}/python/Scripts/Lib", f"{dir}/python/Scripts/Lib/site-packages"], - "PY_PIP": [f"{dir}/python/Scripts"], - "PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"], + "PY_LIBS": [ + f"{dir}/Scripts/Lib", + f"{dir}/Scripts/Lib/site-packages", + f"{dir}/lib", + f"{dir}/lib/python3.10/site-packages", + ], + "PY_PIP": [f"{dir}/Scripts", f"{dir}/bin"], + "PIP_INSTALLER_LOCATION": [], # [f"{dir}/python/get-pip.py"], "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], - "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args}'], + "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args} --skip-torch-cuda-test'], "SKIP_VENV": ["1"], "SD_WEBUI_RESTARTING": ["1"], - "PYTHON": [f"{dir}/python/python"], - "GIT": [f"{dir}/git/bin/git"], } - if platform.system() == "Windows": + if OS_NAME == "Windows": + env_entries["PATH"].append("C:/Windows/System32") + env_entries["PATH"].append("C:/Windows/System32/wbem") env_entries["PYTHONNOUSERSITE"] = ["1"] + env_entries["PYTHON"] = [f"{dir}/python"] + env_entries["GIT"] = [f"{dir}/Library/bin/git"] else: + env_entries["PATH"].append("/bin") + env_entries["PATH"].append("/usr/bin") + env_entries["PATH"].append("/usr/sbin") env_entries["PYTHONNOUSERSITE"] = ["y"] + env_entries["PYTHON"] = [f"{dir}/bin/python"] + env_entries["GIT"] = [f"{dir}/bin/git"] + env_entries["venv_dir"] = ["-"] + + if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card(): + env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu" + + if OS_NAME == "Darwin": + # based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/e26abf87ecd1eefd9ab0a198eee56f9c643e4001/webui-macos-env.sh + # hack - have to define these here, otherwise webui-macos-env.sh will 
overwrite COMMANDLINE_ARGS + env_entries["COMMANDLINE_ARGS"][0] += " --upcast-sampling --no-half-vae --use-cpu interrogate" + env_entries["PYTORCH_ENABLE_MPS_FALLBACK"] = ["1"] + + cpu_name = str(subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])) + if "Intel" in cpu_name: + env_entries["TORCH_COMMAND"] = ["pip install torch==2.1.2 torchvision==0.16.2"] + else: + env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"] env = {} for key, paths in env_entries.items(): @@ -169,6 +291,40 @@ def get_env(): return env +def has_discrete_graphics_card(): + system = OS_NAME + + if system == "Windows": + try: + output = subprocess.check_output( + ["wmic", "path", "win32_videocontroller", "get", "name"], stderr=subprocess.STDOUT + ) + # Filter for discrete graphics cards (NVIDIA, AMD, etc.) + discrete_gpus = ["NVIDIA", "AMD", "ATI"] + return any(gpu in output.decode() for gpu in discrete_gpus) + except subprocess.CalledProcessError: + return False + + elif system == "Linux": + try: + output = subprocess.check_output(["lspci"], stderr=subprocess.STDOUT) + # Check for discrete GPUs (NVIDIA, AMD) + discrete_gpus = ["NVIDIA", "AMD", "Advanced Micro Devices"] + return any(gpu in line for line in output.decode().splitlines() for gpu in discrete_gpus) + except subprocess.CalledProcessError: + return False + + elif system == "Darwin": # macOS + try: + output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], stderr=subprocess.STDOUT) + # Check for discrete GPU in the output + return "NVIDIA" in output.decode() or "AMD" in output.decode() + except subprocess.CalledProcessError: + return False + + return False + + # https://stackoverflow.com/a/25134985 def kill(proc_pid): process = psutil.Process(proc_pid) diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 5ad6420d..38cba0b6 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -234,6 +234,7 @@ def 
thread_render(device): from easydiffusion import model_manager, runtime from easydiffusion.backend_manager import backend + from requests import ConnectionError try: runtime.init(device) @@ -252,7 +253,7 @@ def thread_render(device): break time.sleep(1) - except TimeoutError: + except (TimeoutError, ConnectionError): time.sleep(1) model_manager.load_default_models(runtime.context) From 9abc76482c4a4fc31944cac89e6c2f541a2e464c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 7 Oct 2024 12:39:07 +0530 Subject: [PATCH 14/64] Wait until the webui backend responds with a 200 OK to ping requests --- ui/easydiffusion/backends/webui/impl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py index 393f8ea9..e0adae96 100644 --- a/ui/easydiffusion/backends/webui/impl.py +++ b/ui/easydiffusion/backends/webui/impl.py @@ -68,7 +68,10 @@ def ping(timeout=1): global webui_opts try: - webui_get("/internal/ping", timeout=timeout) + res = webui_get("/internal/ping", timeout=timeout) + + if res.status_code != 200: + raise ConnectTimeout(res.text) if webui_opts is None: try: @@ -85,7 +88,7 @@ def ping(timeout=1): webui_opts = res.json() except Exception as e: - print(f"Error setting options: {e}") + print(f"Error getting options: {e}") return True except ConnectTimeout as e: From b6ba782c35124bfa30de749211c7673da93faaa1 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 7 Oct 2024 12:48:27 +0530 Subject: [PATCH 15/64] Support both WebUI and ED folder names for models --- ui/easydiffusion/backends/webui/__init__.py | 6 +- ui/easydiffusion/model_manager.py | 113 ++++++++++++-------- ui/easydiffusion/server.py | 2 +- ui/easydiffusion/utils/nsfw_checker.py | 4 +- 4 files changed, 75 insertions(+), 50 deletions(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index 95cd7883..6b21ae1c 100644 --- 
a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -7,7 +7,7 @@ import psutil import shutil from easydiffusion.app import ROOT_DIR, getConfig -from easydiffusion.model_manager import get_model_dir +from easydiffusion.model_manager import get_model_dirs from . import impl from .impl import ( @@ -74,6 +74,8 @@ def install_backend(): # create the conda env run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR) + print("Installing packages..") + # install python 3.10 and git in the conda env run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR) @@ -336,7 +338,7 @@ def kill(proc_pid): def get_model_path_args(): args = [] for model_type, flag in MODELS_TO_OVERRIDE.items(): - model_dir = get_model_dir(model_type) + model_dir = get_model_dirs(model_type)[0] args.append(f'{flag} "{model_dir}"') return " ".join(args) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 2a20fce9..60c286c9 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -51,6 +51,16 @@ DEFAULT_MODELS = { ], } MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"] +ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility + "stable-diffusion": "Stable-diffusion", + "vae": "VAE", + "hypernetwork": "hypernetworks", + "codeformer": "Codeformer", + "gfpgan": "GFPGAN", + "realesrgan": "RealESRGAN", + "lora": "Lora", + "controlnet": "ControlNet", +} known_models = {} @@ -122,33 +132,33 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() - model_dir = get_model_dir(model_type) if not model_name: # When None try user configured model. 
# config = getConfig() if "model" in config and model_type in config["model"]: model_name = config["model"][model_type] - if model_name: - # Check models directory - model_path = os.path.join(model_dir, model_name) - if os.path.exists(model_path): - return model_path - for model_extension in model_extensions: - if os.path.exists(model_path + model_extension): - return model_path + model_extension - if os.path.exists(model_name + model_extension): - return os.path.abspath(model_name + model_extension) + for model_dir in get_model_dirs(model_type): + if model_name: + # Check models directory + model_path = os.path.join(model_dir, model_name) + if os.path.exists(model_path): + return model_path + for model_extension in model_extensions: + if os.path.exists(model_path + model_extension): + return model_path + model_extension + if os.path.exists(model_name + model_extension): + return os.path.abspath(model_name + model_extension) - # Can't find requested model, check the default paths. - if model_type == "stable-diffusion" and not fail_if_not_found: - for default_model in default_models: - default_model_path = os.path.join(model_dir, default_model["file_name"]) - if os.path.exists(default_model_path): - if model_name is not None: - log.warn( - f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}" - ) - return default_model_path + # Can't find requested model, check the default paths. + if model_type == "stable-diffusion" and not fail_if_not_found: + for default_model in default_models: + default_model_path = os.path.join(model_dir, default_model["file_name"]) + if os.path.exists(default_model_path): + if model_name is not None: + log.warn( + f"Could not find the configured custom model {model_name}. 
Using the default one: {default_model_path}" + ) + return default_model_path if model_name and fail_if_not_found: raise FileNotFoundError( @@ -239,7 +249,7 @@ def download_default_models_if_necessary(): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): - model_dir = get_model_dir(model_type) + model_dir = get_model_dirs(model_type)[0] model_path = os.path.join(model_dir, file_name) expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] @@ -260,23 +270,23 @@ def migrate_legacy_model_location(): file_name = model["file_name"] legacy_path = os.path.join(app.SD_DIR, file_name) if os.path.exists(legacy_path): - model_dir = get_model_dir(model_type) + model_dir = get_model_dirs(model_type)[0] shutil.move(legacy_path, os.path.join(model_dir, file_name)) def any_model_exists(model_type: str) -> bool: extensions = MODEL_EXTENSIONS.get(model_type, []) - model_dir = get_model_dir(model_type) - for ext in extensions: - if any(glob(f"{model_dir}/**/*{ext}", recursive=True)): - return True + for model_dir in get_model_dirs(model_type): + for ext in extensions: + if any(glob(f"{model_dir}/**/*{ext}", recursive=True)): + return True return False def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: - model_dir_path = get_model_dir(model_type) + model_dir_path = get_model_dirs(model_type)[0] try: os.makedirs(model_dir_path, exist_ok=True) @@ -377,6 +387,9 @@ def getModels(scan_for_malicious: bool = True): tree = list(default_entries) + if not os.path.exists(directory): + return tree + for entry in sorted( os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), @@ -421,17 +434,23 @@ def getModels(scan_for_malicious: bool = True): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) - models_dir = get_model_dir(model_type) - if not os.path.exists(models_dir): - os.makedirs(models_dir) + models_dirs = 
get_model_dirs(model_type) + if not os.path.exists(models_dirs[0]): + os.makedirs(models_dirs[0]) - try: - default_tree = models["options"].get(model_type, []) - models["options"][model_type] = scan_directory( - models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter - ) - except MaliciousModelException as e: - models["scan-error"] = str(e) + models["options"][model_type] = [] + default_tree = models["options"].get(model_type, []) + + for model_dir in models_dirs: + try: + scanned_models = scan_directory( + model_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter + ) + for m in scanned_models: + if m not in models["options"][model_type]: + models["options"][model_type].append(m) + except MaliciousModelException as e: + models["scan-error"] = str(e) if scan_for_malicious: log.info(f"[green]Scanning all model folders for models...[/]") @@ -450,17 +469,21 @@ def getModels(scan_for_malicious: bool = True): return models -def get_model_dir(model_type: str, base_dir=None): - "Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)" +def get_model_dirs(model_type: str, base_dir=None): + "Returns the possible model directory paths for the given model type. 
Mainly used for WebUI compatibility" if base_dir is None: base_dir = app.MODELS_DIR - for dir in os.listdir(base_dir): - if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)): - return os.path.join(base_dir, dir) + dirs = [os.path.join(base_dir, model_type)] - return os.path.join(base_dir, model_type) + if model_type in ALTERNATE_FOLDER_NAMES: + alt_dir = ALTERNATE_FOLDER_NAMES[model_type] + alt_dir = os.path.join(base_dir, alt_dir) + if os.path.exists(alt_dir) and os.path.isdir(alt_dir): + dirs.append(alt_dir) + + return dirs # patch sdkit @@ -468,7 +491,7 @@ def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_mod "Patched version that works with case-insensitive model sub-dirs" download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir - download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir + download_base_dir = get_model_dirs(model_type, download_base_dir)[0] if subdir_for_model_type else download_base_dir return os.path.abspath(download_base_dir) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index f1b85764..63d940aa 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -364,7 +364,7 @@ def model_merge_internal(req: dict): mergeReq: MergeRequest = MergeRequest.parse_obj(req) - sd_model_dir = model_manager.get_model_dir("stable-diffusion") + sd_model_dir = model_manager.get_model_dirs("stable-diffusion")[0] merge_models( model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), diff --git a/ui/easydiffusion/utils/nsfw_checker.py b/ui/easydiffusion/utils/nsfw_checker.py index 9e371a37..51a684df 100644 --- a/ui/easydiffusion/utils/nsfw_checker.py +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -13,7 +13,7 @@ nsfw_check_model = None def filter_nsfw(images, blur_radius: float = 75, print_log=True): global nsfw_check_model - from
easydiffusion.model_manager import get_model_dir + from easydiffusion.model_manager import get_model_dirs from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick import onnxruntime as ort @@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True): import numpy as np if nsfw_check_model is None: - model_dir = get_model_dir("nsfw-checker") + model_dir = get_model_dirs("nsfw-checker")[0] model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") os.makedirs(model_dir, exist_ok=True) From 5a5d37ba5206a738b367e29ce06298b0a810fed2 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 7 Oct 2024 12:59:19 +0530 Subject: [PATCH 16/64] Enable rendering only the CPU --- ui/easydiffusion/backends/webui/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index 6b21ae1c..d35c4539 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -268,7 +268,9 @@ def get_env(): env_entries["GIT"] = [f"{dir}/bin/git"] env_entries["venv_dir"] = ["-"] - if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card(): + if config.get("render_devices", "auto") == "cpu" or ( + OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card() + ): env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu" if OS_NAME == "Darwin": From c1193377b6fa510a7f3314456b08e415bbd58c96 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 7 Oct 2024 13:33:25 +0530 Subject: [PATCH 17/64] Use vram_usage_level while starting webui --- ui/easydiffusion/backends/webui/__init__.py | 13 ++++++++----- ui/easydiffusion/server.py | 2 ++ ui/media/js/parameters.js | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index d35c4539..126b74bc 100644 --- 
a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -268,11 +268,6 @@ def get_env(): env_entries["GIT"] = [f"{dir}/bin/git"] env_entries["venv_dir"] = ["-"] - if config.get("render_devices", "auto") == "cpu" or ( - OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card() - ): - env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu" - if OS_NAME == "Darwin": # based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/e26abf87ecd1eefd9ab0a198eee56f9c643e4001/webui-macos-env.sh # hack - have to define these here, otherwise webui-macos-env.sh will overwrite COMMANDLINE_ARGS @@ -284,6 +279,14 @@ def get_env(): env_entries["TORCH_COMMAND"] = ["pip install torch==2.1.2 torchvision==0.16.2"] else: env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"] + else: + vram_usage_level = config.get("vram_usage_level", "balanced") + if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card(): + env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu" + elif vram_usage_level == "low": + env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram" + elif vram_usage_level == "high": + env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram" env = {} for key, paths in env_entries.items(): diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 63d940aa..7a82680e 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -71,6 +71,7 @@ class SetAppConfigRequest(BaseModel, extra=Extra.allow): use_v3_engine: bool = True backend: str = "ed_diffusers" models_dir: str = None + vram_usage_level: str = "balanced" def init(): @@ -188,6 +189,7 @@ def set_app_config_internal(req: SetAppConfigRequest): config["use_v3_engine"] = req.backend == "ed_diffusers" config["backend"] = req.backend config["models_dir"] = req.models_dir + config["vram_usage_level"] = req.vram_usage_level for property, property_value in req.dict().items(): if 
property_value is not None and property not in req.__fields__ and property not in PROTECTED_CONFIG_KEYS: diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index c8624f40..65bc82d5 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -161,6 +161,7 @@ var PARAMETERS = [ "Low: slowest, recommended for GPUs with 3 to 4 GB memory", icon: "fa-forward", default: "balanced", + saveInAppConfig: true, options: [ { value: "balanced", label: "Balanced" }, { value: "high", label: "High" }, From 84c8284a90cd6ad99bb9d078e79414ea34934de5 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 8 Oct 2024 18:35:08 +0530 Subject: [PATCH 18/64] Proxy webui api --- ui/easydiffusion/backends/webui/__init__.py | 48 +++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index 126b74bc..52195fc6 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -132,6 +132,54 @@ def start_backend(): backend_thread = threading.Thread(target=target) backend_thread.start() + start_proxy() + + +def start_proxy(): + # proxy + from easydiffusion.server import server_api + from fastapi import FastAPI, Request + from fastapi.responses import Response + import json + + URI_PREFIX = "/webui" + + webui_proxy = FastAPI(root_path=f"{URI_PREFIX}", docs_url="/swagger") + + @webui_proxy.get("{uri:path}") + def proxy_get(uri: str, req: Request): + if uri == "/openapi-proxy.json": + uri = "/openapi.json" + + res = impl.webui_get(uri, headers=req.headers) + + content = res.content + headers = dict(res.headers) + + if uri == "/docs": + content = res.text.replace("url: '/openapi.json'", f"url: '{URI_PREFIX}/openapi-proxy.json'") + elif uri == "/openapi.json": + content = res.json() + content["paths"] = {f"{URI_PREFIX}{k}": v for k, v in content["paths"].items()} + content = json.dumps(content) + + if isinstance(content, str): + content = 
bytes(content, encoding="utf-8") + headers["content-length"] = str(len(content)) + + # Return the same response back to the client + return Response(content=content, status_code=res.status_code, headers=headers) + + @webui_proxy.post("{uri:path}") + async def proxy_post(uri: str, req: Request): + body = await req.body() + res = impl.webui_post(uri, data=body, headers=req.headers) + + # Return the same response back to the client + return Response(content=res.content, status_code=res.status_code, headers=dict(res.headers)) + + server_api.mount(f"{URI_PREFIX}", webui_proxy) + def stop_backend(): global backend_process From 90bc1456c9ff62a84231a372cd89414c91ac8c43 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 8 Oct 2024 18:54:06 +0530 Subject: [PATCH 19/64] Suggest guidance value for Flux and non-Flux models --- ui/index.html | 1 + ui/media/css/main.css | 5 +++++ ui/media/js/main.js | 26 ++++++++++++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/ui/index.html b/ui/index.html index 2fd11a5c..74f07984 100644 --- a/ui/index.html +++ b/ui/index.html @@ -414,6 +414,7 @@ + + - + + + - + + + diff --git a/ui/media/js/main.js b/ui/media/js/main.js index e5f58edd..e59a6ace 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -1932,13 +1932,52 @@ function checkFluxSampler() { samplerWarning.classList.add("displayNone") } } + +function checkFluxScheduler() { + const badSchedulers = ["automatic", "uniform", "turbo", "align_your_steps", "align_your_steps_GITS", "align_your_steps_11", "align_your_steps_32"] + + let schedulerWarning = document.querySelector("#fluxSchedulerWarning") + if (sdModelField.value.toLowerCase().includes("flux")) { + if (badSchedulers.includes(schedulerField.value)) { + schedulerWarning.classList.remove("displayNone") + } else { + schedulerWarning.classList.add("displayNone") + } + } else { + schedulerWarning.classList.add("displayNone") + } +} + +function checkFluxSchedulerSteps() { + const problematicSchedulers = ["karras", 
"exponential", "polyexponential"] + + let schedulerWarning = document.querySelector("#fluxSchedulerStepsWarning") + if (sdModelField.value.toLowerCase().includes("flux") && parseInt(numInferenceStepsField.value) < 15) { + if (problematicSchedulers.includes(schedulerField.value)) { + schedulerWarning.classList.remove("displayNone") + } else { + schedulerWarning.classList.add("displayNone") + } + } else { + schedulerWarning.classList.add("displayNone") + } +} sdModelField.addEventListener("change", checkFluxSampler) samplerField.addEventListener("change", checkFluxSampler) +sdModelField.addEventListener("change", checkFluxScheduler) +schedulerField.addEventListener("change", checkFluxScheduler) + +sdModelField.addEventListener("change", checkFluxSchedulerSteps) +schedulerField.addEventListener("change", checkFluxSchedulerSteps) +numInferenceStepsField.addEventListener("change", checkFluxSchedulerSteps) + document.addEventListener("refreshModels", function() { checkGuidanceValue() checkGuidanceScaleVisibility() checkFluxSampler() + checkFluxScheduler() + checkFluxSchedulerSteps() }) // function onControlImageFilterChange() { From 459bfd428098ee29e81a751bdc519be1e0f99076 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 18 Oct 2024 08:30:29 +0530 Subject: [PATCH 60/64] Temp patch for missing attribute --- ui/easydiffusion/runtime.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/easydiffusion/runtime.py b/ui/easydiffusion/runtime.py index accced00..d85839c0 100644 --- a/ui/easydiffusion/runtime.py +++ b/ui/easydiffusion/runtime.py @@ -38,7 +38,7 @@ def set_vram_optimizations(context): config = getConfig() vram_usage_level = config.get("vram_usage_level", "balanced") - if vram_usage_level != context.vram_usage_level: + if hasattr(context, "vram_usage_level") and vram_usage_level != context.vram_usage_level: context.vram_usage_level = vram_usage_level return True From d4ea34a013f75b66cb1bf06504d8038fb3643c82 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: 
Mon, 28 Oct 2024 18:34:00 +0530 Subject: [PATCH 61/64] Increase the webui health-check timeout to 30 seconds (from 1 second); Also trap ReadTimeout in the collection of TimeoutError --- ui/easydiffusion/backends/webui/__init__.py | 2 +- ui/easydiffusion/backends/webui/impl.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index af411c11..fda5e4e3 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -126,7 +126,7 @@ def start_backend(): while True: try: - impl.ping(timeout=1) + impl.ping(timeout=30) is_first_start = not has_started has_started = True diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py index 0560e2f8..0e644bd5 100644 --- a/ui/easydiffusion/backends/webui/impl.py +++ b/ui/easydiffusion/backends/webui/impl.py @@ -1,6 +1,6 @@ import os import requests -from requests.exceptions import ConnectTimeout, ConnectionError +from requests.exceptions import ConnectTimeout, ConnectionError, ReadTimeout from typing import Union, List from threading import local as Context from threading import Thread @@ -8,7 +8,7 @@ import uuid import time from copy import deepcopy -from sdkit.utils import base64_str_to_img, img_to_base64_str +from sdkit.utils import base64_str_to_img, img_to_base64_str, log WEBUI_HOST = "localhost" WEBUI_PORT = "7860" @@ -91,7 +91,7 @@ def ping(timeout=1): print(f"Error getting options: {e}") return True - except (ConnectTimeout, ConnectionError) as e: + except (ConnectTimeout, ConnectionError, ReadTimeout) as e: raise TimeoutError(e) From d0be4edf1df71974c08bed75ca913e98599ea277 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 28 Oct 2024 18:34:36 +0530 Subject: [PATCH 62/64] Update to the latest forge commit - b592142f3b46852263747a4efd0d244ad17b5bb3 --- ui/easydiffusion/backends/webui/__init__.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index fda5e4e3..947b2e24 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -33,7 +33,7 @@ ed_info = { } WEBUI_REPO = "https://github.com/lllyasviel/stable-diffusion-webui-forge.git" -WEBUI_COMMIT = "f4d5e8cac16a42fa939e78a0956b4c30e2b47bb5" +WEBUI_COMMIT = "b592142f3b46852263747a4efd0d244ad17b5bb3" BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui")) SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") From 96ec3ed270850079e891e489a5cb461c34818ff7 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 13 Nov 2024 21:36:35 +0530 Subject: [PATCH 63/64] Pin huggingface-hub to 0.23.2 to fix broken deployments --- scripts/check_modules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 0d58d43d..aa106169 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -35,6 +35,7 @@ modules_to_check = { "python-multipart": "0.0.6", # "xformers": "0.0.16", "onnxruntime": "1.19.2", + "huggingface-hub": "0.23.2", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"] From 3c9ffcf7ca5aa590dad8c62dddc7f4b639fb6908 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 13 Nov 2024 21:49:17 +0530 Subject: [PATCH 64/64] Update check_modules.py --- scripts/check_modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index aa106169..50887ea5 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -35,7 +35,7 @@ modules_to_check = { "python-multipart": "0.0.6", # "xformers": "0.0.16", "onnxruntime": "1.19.2", - "huggingface-hub": "0.23.2", + "huggingface-hub": "0.21.4", } modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]

    diff --git a/ui/media/css/main.css b/ui/media/css/main.css index c5e862fe..aabdccf9 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -670,6 +670,11 @@ div.img-preview img { display: none; } +#guidanceWarningText { + color: var(--status-orange); + font-size: 9pt; +} + .display-settings { float: right; position: relative; diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 2b90a6d3..a1023286 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -1858,6 +1858,32 @@ controlImagePreview.addEventListener("load", onControlnetModelChange) controlImagePreview.addEventListener("unload", onControlnetModelChange) onControlnetModelChange() +// tip for Flux +let sdModelField = document.querySelector("#stable_diffusion_model") +function checkGuidanceValue() { + let guidance = parseFloat(guidanceScaleField.value) + let guidanceWarning = document.querySelector("#guidanceWarning") + let guidanceWarningText = document.querySelector("#guidanceWarningText") + if (sdModelField.value.toLowerCase().includes("flux")) { + if (guidance > 1.5) { + guidanceWarningText.innerText = "Flux recommends a guidance scale of 1" + guidanceWarning.classList.remove("displayNone") + } else { + guidanceWarning.classList.add("displayNone") + } + } else { + if (guidance < 2) { + guidanceWarningText.innerText = "A higher Guidance Scale is recommended!" 
+ guidanceWarning.classList.remove("displayNone") + } else { + guidanceWarning.classList.add("displayNone") + } + } +} +sdModelField.addEventListener("change", checkGuidanceValue) +guidanceScaleField.addEventListener("change", checkGuidanceValue) +guidanceScaleSlider.addEventListener("change", checkGuidanceValue) + // function onControlImageFilterChange() { // let filterId = controlImageFilterField.value // if (filterId.includes("openpose")) { From d283fb0776ef0d0ced03ceaa67c34c71a62a6160 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 9 Oct 2024 11:12:28 +0530 Subject: [PATCH 20/64] Option to select Distilled Guidance Scale; Show warnings for Euler Ancestral sampler with Flux --- ui/easydiffusion/backends/sdkit_common.py | 1 + ui/easydiffusion/backends/webui/impl.py | 2 + ui/easydiffusion/types.py | 1 + ui/index.html | 6 +- ui/media/css/main.css | 9 +-- ui/media/js/main.js | 68 ++++++++++++++++++++++- 6 files changed, 76 insertions(+), 11 deletions(-) diff --git a/ui/easydiffusion/backends/sdkit_common.py b/ui/easydiffusion/backends/sdkit_common.py index 4c7eba59..bda1414c 100644 --- a/ui/easydiffusion/backends/sdkit_common.py +++ b/ui/easydiffusion/backends/sdkit_common.py @@ -81,6 +81,7 @@ def generate_images( context: Context, callback=None, controlnet_filter=None, + distilled_guidance_scale: float = 3.5, output_type="pil", **req, ): diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py index e0adae96..6c8c66bd 100644 --- a/ui/easydiffusion/backends/webui/impl.py +++ b/ui/easydiffusion/backends/webui/impl.py @@ -150,6 +150,7 @@ def generate_images( num_outputs: int = 1, num_inference_steps: int = 25, guidance_scale: float = 7.5, + distilled_guidance_scale: float = 3.5, init_image=None, init_image_mask=None, control_image=None, @@ -181,6 +182,7 @@ def generate_images( "steps": num_inference_steps, "seed": seed, "cfg_scale": guidance_scale, + "distilled_cfg_scale": distilled_guidance_scale, "batch_size": num_outputs, 
"width": width, "height": height, diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 93bfe08f..c26e4fbc 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -14,6 +14,7 @@ class GenerateImageRequest(BaseModel): num_outputs: int = 1 num_inference_steps: int = 50 guidance_scale: float = 7.5 + distilled_guidance_scale: float = 3.5 init_image: Any = None init_image_mask: Any = None diff --git a/ui/index.html b/ui/index.html index 74f07984..1284d785 100644 --- a/ui/index.html +++ b/ui/index.html @@ -336,6 +336,7 @@ Click to learn more about samplers
    Please avoid 'Euler Ancestral' with Flux!

    diff --git a/ui/media/css/main.css b/ui/media/css/main.css index aabdccf9..39b7b6e4 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -670,9 +670,9 @@ div.img-preview img { display: none; } -#guidanceWarningText { +.warning-label { + font-size: smaller; color: var(--status-orange); - font-size: 9pt; } .display-settings { @@ -1477,11 +1477,6 @@ div.top-right { margin-top: 6px; } -#small_image_warning { - font-size: smaller; - color: var(--status-orange); -} - button#save-system-settings-btn { padding: 4pt 8pt; } diff --git a/ui/media/js/main.js b/ui/media/js/main.js index a1023286..3f372911 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -14,6 +14,10 @@ const taskConfigSetup = { sampler_name: "Sampler", num_inference_steps: "Inference Steps", guidance_scale: "Guidance Scale", + distilled_guidance_scale: { + label: "Distilled Guidance Scale", + visible: ({ reqBody }) => reqBody?.distilled_guidance_scale, + }, use_stable_diffusion_model: "Model", clip_skip: { label: "Clip Skip", @@ -76,6 +80,8 @@ let numOutputsParallelField = document.querySelector("#num_outputs_parallel") let numInferenceStepsField = document.querySelector("#num_inference_steps") let guidanceScaleSlider = document.querySelector("#guidance_scale_slider") let guidanceScaleField = document.querySelector("#guidance_scale") +let distilledGuidanceScaleSlider = document.querySelector("#distilled_guidance_scale_slider") +let distilledGuidanceScaleField = document.querySelector("#distilled_guidance_scale") let outputQualitySlider = document.querySelector("#output_quality_slider") let outputQualityField = document.querySelector("#output_quality") let outputQualityRow = document.querySelector("#output_quality_row") @@ -1051,6 +1057,9 @@ function makeImage() { if (guidanceScaleField.value == "") { guidanceScaleField.value = guidanceScaleSlider.value / 10 } + if (distilledGuidanceScaleField.value == "") { + distilledGuidanceScaleField.value = distilledGuidanceScaleSlider.value / 
10 + } if (hypernetworkStrengthField.value == "") { hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100 } @@ -1419,6 +1428,9 @@ function getCurrentUserRequest() { newTask.reqBody.control_filter_to_apply = controlImageFilterField.value } } + if (stableDiffusionModelField.value.toLowerCase().includes("flux")) { + newTask.reqBody.distilled_guidance_scale = parseFloat(distilledGuidanceScaleField.value) + } return newTask } @@ -1866,14 +1878,14 @@ function checkGuidanceValue() { let guidanceWarningText = document.querySelector("#guidanceWarningText") if (sdModelField.value.toLowerCase().includes("flux")) { if (guidance > 1.5) { - guidanceWarningText.innerText = "Flux recommends a guidance scale of 1" + guidanceWarningText.innerText = "Flux recommends a 'Guidance Scale' of 1" guidanceWarning.classList.remove("displayNone") } else { guidanceWarning.classList.add("displayNone") } } else { if (guidance < 2) { - guidanceWarningText.innerText = "A higher Guidance Scale is recommended!" + guidanceWarningText.innerText = "A higher 'Guidance Scale' is recommended!" 
guidanceWarning.classList.remove("displayNone") } else { guidanceWarning.classList.add("displayNone") @@ -1884,6 +1896,37 @@ sdModelField.addEventListener("change", checkGuidanceValue) guidanceScaleField.addEventListener("change", checkGuidanceValue) guidanceScaleSlider.addEventListener("change", checkGuidanceValue) +function checkGuidanceScaleVisibility() { + let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container") + if (sdModelField.value.toLowerCase().includes("flux")) { + guidanceScaleContainer.classList.remove("displayNone") + } else { + guidanceScaleContainer.classList.add("displayNone") + } +} +sdModelField.addEventListener("change", checkGuidanceScaleVisibility) + +function checkFluxSampler() { + let samplerWarning = document.querySelector("#fluxSamplerWarning") + if (sdModelField.value.toLowerCase().includes("flux")) { + if (samplerField.value == "euler_a") { + samplerWarning.classList.remove("displayNone") + } else { + samplerWarning.classList.add("displayNone") + } + } else { + samplerWarning.classList.add("displayNone") + } +} +sdModelField.addEventListener("change", checkFluxSampler) +samplerField.addEventListener("change", checkFluxSampler) + +document.addEventListener("refreshModels", function() { + checkGuidanceValue() + checkGuidanceScaleVisibility() + checkFluxSampler() +}) + // function onControlImageFilterChange() { // let filterId = controlImageFilterField.value // if (filterId.includes("openpose")) { @@ -2012,6 +2055,27 @@ guidanceScaleSlider.addEventListener("input", updateGuidanceScale) guidanceScaleField.addEventListener("input", updateGuidanceScaleSlider) updateGuidanceScale() +/********************* Distilled Guidance **************************/ +function updateDistilledGuidanceScale() { + distilledGuidanceScaleField.value = distilledGuidanceScaleSlider.value / 10 + distilledGuidanceScaleField.dispatchEvent(new Event("change")) +} + +function updateDistilledGuidanceScaleSlider() { + if 
(distilledGuidanceScaleField.value < 0) { + distilledGuidanceScaleField.value = 0 + } else if (distilledGuidanceScaleField.value > 50) { + distilledGuidanceScaleField.value = 50 + } + + distilledGuidanceScaleSlider.value = distilledGuidanceScaleField.value * 10 + distilledGuidanceScaleSlider.dispatchEvent(new Event("change")) +} + +distilledGuidanceScaleSlider.addEventListener("input", updateDistilledGuidanceScale) +distilledGuidanceScaleField.addEventListener("input", updateDistilledGuidanceScaleSlider) +updateDistilledGuidanceScale() + /********************* Prompt Strength *******************/ function updatePromptStrength() { promptStrengthField.value = promptStrengthSlider.value / 100 From 3327244da277a806e05c596f94248e62770e2103 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 9 Oct 2024 11:43:45 +0530 Subject: [PATCH 21/64] Scheduler selection in the UI; Store Scheduler and Distilled Guidance in metadata --- ui/easydiffusion/backends/sdkit_common.py | 1 + ui/easydiffusion/backends/webui/impl.py | 3 ++- ui/easydiffusion/types.py | 1 + ui/easydiffusion/utils/save_utils.py | 2 ++ ui/index.html | 20 ++++++++++++++++++++ ui/media/js/auto-save.js | 2 ++ ui/media/js/dnd.js | 11 +++++++++++ ui/media/js/main.js | 9 +++++++++ 8 files changed, 48 insertions(+), 1 deletion(-) diff --git a/ui/easydiffusion/backends/sdkit_common.py b/ui/easydiffusion/backends/sdkit_common.py index bda1414c..1feae994 100644 --- a/ui/easydiffusion/backends/sdkit_common.py +++ b/ui/easydiffusion/backends/sdkit_common.py @@ -82,6 +82,7 @@ def generate_images( callback=None, controlnet_filter=None, distilled_guidance_scale: float = 3.5, + scheduler_name: str = "simple", output_type="pil", **req, ): diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py index 6c8c66bd..2853e8fb 100644 --- a/ui/easydiffusion/backends/webui/impl.py +++ b/ui/easydiffusion/backends/webui/impl.py @@ -160,6 +160,7 @@ def generate_images( preserve_init_image_color_profile=False, 
strict_mask_border=False, sampler_name: str = "euler_a", + scheduler_name: str = "simple", hypernetwork_strength: float = 0, tiling=None, lora_alpha: Union[float, List[float]] = 0, @@ -178,7 +179,7 @@ def generate_images( "prompt": prompt, "negative_prompt": negative_prompt, "sampler_name": sampler_name, - "scheduler": "simple", + "scheduler": scheduler_name, "steps": num_inference_steps, "seed": seed, "cfg_scale": guidance_scale, diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index c26e4fbc..bc0ccabf 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -26,6 +26,7 @@ class GenerateImageRequest(BaseModel): strict_mask_border: bool = False sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + scheduler_name: str = None hypernetwork_strength: float = 0 lora_alpha: Union[float, List[float]] = 0 tiling: str = None # None, "x", "y", "xy" diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 216ec899..29c84f22 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -34,10 +34,12 @@ TASK_TEXT_MAPPING = { "control_alpha": "ControlNet Strength", "use_vae_model": "VAE model", "sampler_name": "Sampler", + "scheduler_name": "Scheduler", "width": "Width", "height": "Height", "num_inference_steps": "Steps", "guidance_scale": "Guidance Scale", + "distilled_guidance_scale": "Distilled Guidance", "prompt_strength": "Prompt Strength", "use_lora_model": "LoRA model", "lora_alpha": "LoRA Strength", diff --git a/ui/index.html b/ui/index.html index 1284d785..72abe26c 100644 --- a/ui/index.html +++ b/ui/index.html @@ -337,6 +337,26 @@ Click to learn more about samplers
    Please avoid 'Euler Ancestral' with Flux!
    + +
    Click to learn more about samplers
    Please avoid 'Euler Ancestral' with Flux!
    Tip:This sampler does not work well with Flux!
    Tip:This scheduler does not work well with Flux!
    Tip:This scheduler needs 15 steps or more