From 9a12a8618cb426c99e4c6175d8234339fc94a5fe Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Oct 2024 10:54:58 +0530 Subject: [PATCH] First working version of dynamic backends, with Forge and ed_diffusers (v3) and ed_classic (v2). Does not auto-install Forge yet --- ui/easydiffusion/app.py | 23 +- ui/easydiffusion/backend_manager.py | 105 ++++ ui/easydiffusion/backends/ed_classic.py | 27 + ui/easydiffusion/backends/ed_diffusers.py | 27 + ui/easydiffusion/backends/sdkit_common.py | 237 ++++++++ ui/easydiffusion/backends/webui/__init__.py | 155 +++++ ui/easydiffusion/backends/webui/impl.py | 639 ++++++++++++++++++++ ui/easydiffusion/easydb/schemas.py | 1 - ui/easydiffusion/model_manager.py | 72 ++- ui/easydiffusion/runtime.py | 34 +- ui/easydiffusion/server.py | 24 +- ui/easydiffusion/task_manager.py | 15 +- ui/easydiffusion/tasks/filter_images.py | 81 +-- ui/easydiffusion/tasks/render_images.py | 224 +++---- ui/easydiffusion/types.py | 67 +- ui/easydiffusion/utils/__init__.py | 3 +- ui/easydiffusion/utils/nsfw_checker.py | 80 +++ ui/easydiffusion/utils/save_utils.py | 2 +- ui/index.html | 131 +++- ui/main.py | 3 - ui/media/css/auto-save.css | 3 +- ui/media/css/main.css | 15 +- ui/media/js/main.js | 15 +- ui/media/js/parameters.js | 88 +-- 24 files changed, 1715 insertions(+), 356 deletions(-) create mode 100644 ui/easydiffusion/backend_manager.py create mode 100644 ui/easydiffusion/backends/ed_classic.py create mode 100644 ui/easydiffusion/backends/ed_diffusers.py create mode 100644 ui/easydiffusion/backends/sdkit_common.py create mode 100644 ui/easydiffusion/backends/webui/__init__.py create mode 100644 ui/easydiffusion/backends/webui/impl.py create mode 100644 ui/easydiffusion/utils/nsfw_checker.py diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index a1f607ef..ec856991 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -11,7 +11,7 @@ from ruamel.yaml import YAML import urllib import warnings -from easydiffusion import task_manager +from easydiffusion import task_manager, backend_manager from easydiffusion.utils import log from rich.logging import RichHandler from rich.console import Console @@ -60,7 +60,7 @@ APP_CONFIG_DEFAULTS = { "ui": { "open_browser_on_start": True, }, - "use_v3_engine": True, + "backend": "ed_diffusers", } IMAGE_EXTENSIONS = [ @@ -108,6 +108,8 @@ def init(): if config_models_dir is not None and config_models_dir != "": MODELS_DIR = config_models_dir + backend_manager.start_backend() + def init_render_threads(): load_server_plugins() @@ -124,9 +126,9 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): shutil.move(config_legacy_yaml, config_yaml_path) def set_config_on_startup(config: dict): - if getConfig.__use_v3_engine_on_startup is None: - getConfig.__use_v3_engine_on_startup = config.get("use_v3_engine", True) - config["config_on_startup"] = {"use_v3_engine": getConfig.__use_v3_engine_on_startup} + if getConfig.__use_backend_on_startup is None: + getConfig.__use_backend_on_startup = config.get("backend", "ed_diffusers") + config["config_on_startup"] = {"backend": getConfig.__use_backend_on_startup} if os.path.isfile(config_yaml_path): try: @@ -144,6 +146,15 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): else: config["net"]["listen_to_network"] = True + if "backend" not in config: + if "use_v3_engine" in config: + config["backend"] = "ed_diffusers" if config["use_v3_engine"] else "ed_classic" + else: + config["backend"] = "ed_diffusers" + # this default will need to be smarter when WebUI becomes the main backend, but needs to maintain backwards + # compatibility with existing ED 3.0 installations that haven't opted into the WebUI backend, and haven't + # set a "use_v3_engine" flag in their config + set_config_on_startup(config) return config @@ -174,7 +185,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): return default_val -getConfig.__use_v3_engine_on_startup = None +getConfig.__use_backend_on_startup = None def setConfig(config): diff --git a/ui/easydiffusion/backend_manager.py b/ui/easydiffusion/backend_manager.py new file mode 100644 index 00000000..1c280db4 --- /dev/null +++ b/ui/easydiffusion/backend_manager.py @@ -0,0 +1,105 @@ +import os +import ast +import sys +import importlib.util +import traceback + +from easydiffusion.utils import log + +backend = None +curr_backend_name = None + + +def is_valid_backend(file_path): + with open(file_path, "r", encoding="utf-8") as file: + node = ast.parse(file.read()) + + # Check for presence of a dictionary named 'ed_info' + for item in node.body: + if isinstance(item, ast.Assign): + for target in item.targets: + if isinstance(target, ast.Name) and target.id == "ed_info": + return True + return False + + +def find_valid_backends(root_dir) -> dict: + backends_path = os.path.join(root_dir, "backends") + valid_backends = {} + + if not os.path.exists(backends_path): + return valid_backends + + for item in os.listdir(backends_path): + item_path = os.path.join(backends_path, item) + + if os.path.isdir(item_path): + init_file = os.path.join(item_path, "__init__.py") + if os.path.exists(init_file) and is_valid_backend(init_file): + valid_backends[item] = item_path + elif item.endswith(".py"): + if is_valid_backend(item_path): + backend_name = os.path.splitext(item)[0] # strip the .py extension + valid_backends[backend_name] = item_path + + return valid_backends + + +def load_backend_module(backend_name, backend_dict): + if backend_name not in backend_dict: + raise ValueError(f"Backend '{backend_name}' not found.") + + module_path = backend_dict[backend_name] + + mod_dir = os.path.dirname(module_path) + + sys.path.insert(0, mod_dir) + + # If it's a package (directory), add its parent directory to sys.path + if os.path.isdir(module_path): + module_path = os.path.join(module_path, "__init__.py") + + spec = importlib.util.spec_from_file_location(backend_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if mod_dir in sys.path: + sys.path.remove(mod_dir) + + log.info(f"Loaded backend: {module}") + + return module + + +def start_backend(): + global backend, curr_backend_name + + from easydiffusion.app import getConfig, ROOT_DIR + + curr_dir = os.path.dirname(__file__) + + backends = find_valid_backends(curr_dir) + plugin_backends = find_valid_backends(ROOT_DIR) + backends.update(plugin_backends) + + config = getConfig() + backend_name = config["backend"] + + if backend_name not in backends: + raise RuntimeError( + f"Couldn't find the backend configured in config.yaml: {backend_name}. Please check the name!" + ) + + if backend is not None and backend_name != curr_backend_name: + try: + backend.stop_backend() + except: + log.exception(traceback.format_exc()) + + log.info(f"Loading backend: {backend_name}") + backend = load_backend_module(backend_name, backends) + + try: + backend.start_backend() + except: + log.exception(traceback.format_exc()) diff --git a/ui/easydiffusion/backends/ed_classic.py b/ui/easydiffusion/backends/ed_classic.py new file mode 100644 index 00000000..c9cf745e --- /dev/null +++ b/ui/easydiffusion/backends/ed_classic.py @@ -0,0 +1,27 @@ +from sdkit_common import ( + start_backend, + stop_backend, + install_backend, + uninstall_backend, + create_sdkit_context, + ping, + load_model, + unload_model, + set_options, + generate_images, + filter_images, + get_url, + stop_rendering, + refresh_models, + list_controlnet_filters, +) + +ed_info = { + "name": "Classic backend for Easy Diffusion v2", + "version": (1, 0, 0), + "type": "backend", +} + + +def create_context(): + return create_sdkit_context(use_diffusers=False) diff --git a/ui/easydiffusion/backends/ed_diffusers.py b/ui/easydiffusion/backends/ed_diffusers.py new file mode 100644 index 00000000..c905652d --- /dev/null +++ b/ui/easydiffusion/backends/ed_diffusers.py @@ -0,0 +1,27 @@ +from sdkit_common import ( + start_backend, + stop_backend, + install_backend, + uninstall_backend, + create_sdkit_context, + ping, + load_model, + unload_model, + set_options, + generate_images, + filter_images, + get_url, + stop_rendering, + refresh_models, + list_controlnet_filters, +) + +ed_info = { + "name": "Diffusers Backend for Easy Diffusion v3", + "version": (1, 0, 0), + "type": "backend", +} + + +def create_context(): + return create_sdkit_context(use_diffusers=True) diff --git a/ui/easydiffusion/backends/sdkit_common.py b/ui/easydiffusion/backends/sdkit_common.py new file mode 100644 index 00000000..d7a49c3e --- /dev/null +++ b/ui/easydiffusion/backends/sdkit_common.py @@ -0,0 +1,237 @@ +from sdkit import Context + +from easydiffusion.types import UserInitiatedStop + +from sdkit.utils import ( + diffusers_latent_samples_to_images, + gc, + img_to_base64_str, + latent_samples_to_images, +) + +opts = {} + + +def install_backend(): + pass + + +def start_backend(): + print("Started sdkit backend") + + +def stop_backend(): + pass + + +def uninstall_backend(): + pass + + +def create_sdkit_context(use_diffusers): + c = Context() + c.test_diffusers = use_diffusers + return c + + +def ping(timeout=1): + return True + + +def load_model(context, model_type, **kwargs): + from sdkit.models import load_model + + load_model(context, model_type, **kwargs) + + +def unload_model(context, model_type, **kwargs): + from sdkit.models import unload_model + + unload_model(context, model_type, **kwargs) + + +def set_options(context, **kwargs): + if "vae_tiling" in kwargs and context.test_diffusers: + pipe = context.models["stable-diffusion"]["default"] + vae_tiling = kwargs["vae_tiling"] + + if vae_tiling: + if hasattr(pipe, "enable_vae_tiling"): + pipe.enable_vae_tiling() + else: + if hasattr(pipe, "disable_vae_tiling"): + pipe.disable_vae_tiling() + + for key in ( + "output_format", + "output_quality", + "output_lossless", + "stream_image_progress", + "stream_image_progress_interval", + ): + if key in kwargs: + opts[key] = kwargs[key] + + +def generate_images( + context: Context, + callback=None, + controlnet_filter=None, + output_type="pil", + **req, +): + from sdkit.generate import generate_images + + if req["init_image"] is not None and not context.test_diffusers: + req["sampler_name"] = "ddim" + + gc(context) + + context.stop_processing = False + + if req["control_image"] and controlnet_filter: + controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter) + req["control_image"] = filter_images(context, req["control_image"], controlnet_filter)[0] + + callback = make_step_callback(context, callback) + + try: + images = generate_images(context, callback=callback, **req) + except UserInitiatedStop: + images = [] + if context.partial_x_samples is not None: + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, context.partial_x_samples) + else: + images = latent_samples_to_images(context, context.partial_x_samples) + finally: + if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: + if not context.test_diffusers: + del context.partial_x_samples + context.partial_x_samples = None + + gc(context) + + if output_type == "base64": + output_format = opts.get("output_format", "jpeg") + output_quality = opts.get("output_quality", 75) + output_lossless = opts.get("output_lossless", False) + images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images] + + return images + + +def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"): + gc(context) + + if "nsfw_checker" in filters: + filters.remove("nsfw_checker") # handled by ED directly + + images = _filter_images(context, images, filters, filter_params) + + if input_type == "base64": + output_format = opts.get("output_format", "jpg") + output_quality = opts.get("output_quality", 75) + output_lossless = opts.get("output_lossless", False) + images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images] + + return images + + +def _filter_images(context, images, filters, filter_params={}): + from sdkit.filter import apply_filters + + filters = filters if isinstance(filters, list) else [filters] + filters = convert_ED_controlnet_filter_name(filters) + + for filter_name in filters: + params = filter_params.get(filter_name, {}) + + previous_state = before_filter(context, filter_name, params) + + try: + images = apply_filters(context, filter_name, images, **params) + finally: + after_filter(context, filter_name, params, previous_state) + + return images + + +def before_filter(context, filter_name, filter_params): + if filter_name == "codeformer": + from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use + + default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] + prev_realesrgan_path = None + + upscale_faces = filter_params.get("upscale_faces", False) + if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: + prev_realesrgan_path = context.model_paths.get("realesrgan") + context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") + load_model(context, "realesrgan") + + return prev_realesrgan_path + + +def after_filter(context, filter_name, filter_params, previous_state): + if filter_name == "codeformer": + prev_realesrgan_path = previous_state + if prev_realesrgan_path: + context.model_paths["realesrgan"] = prev_realesrgan_path + load_model(context, "realesrgan") + + +def get_url(): + pass + + +def stop_rendering(context): + context.stop_processing = True + + +def refresh_models(): + pass + + +def list_controlnet_filters(): + from sdkit.models.model_loader.controlnet_filters import filters as cn_filters + + return cn_filters + + +def make_step_callback(context, callback): + def on_step(x_samples, i, *args): + stream_image_progress = opts.get("stream_image_progress", False) + stream_image_progress_interval = opts.get("stream_image_progress_interval", 3) + + if context.test_diffusers: + context.partial_x_samples = (x_samples, args[0]) + else: + context.partial_x_samples = x_samples + + if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, context.partial_x_samples) + else: + images = latent_samples_to_images(context, context.partial_x_samples) + else: + images = None + + if callback: + callback(images, i, *args) + + if context.stop_processing: + raise UserInitiatedStop("User requested that we stop processing") + + return on_step + + +def convert_ED_controlnet_filter_name(filter): + def cn(n): + if n.startswith("controlnet_"): + return n[len("controlnet_") :] + return n + + if isinstance(filter, list): + return [cn(f) for f in filter] + return cn(filter) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py new file mode 100644 index 00000000..f78d1164 --- /dev/null +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -0,0 +1,155 @@ +import os +import platform +import subprocess +import threading +from threading import local +import psutil + +from easydiffusion.app import ROOT_DIR, getConfig + +from . import impl +from .impl import ( + ping, + load_model, + unload_model, + set_options, + generate_images, + filter_images, + get_url, + stop_rendering, + refresh_models, + list_controlnet_filters, +) + + +ed_info = { + "name": "WebUI backend for Easy Diffusion", + "version": (1, 0, 0), + "type": "backend", +} + +BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui")) +SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") +WEBUI_DIR = os.path.join(BACKEND_DIR, "webui") + +backend_process = None + + +def install_backend(): + pass + + +def start_backend(): + config = getConfig() + backend_config = config.get("backend_config", {}) + + if not os.path.exists(BACKEND_DIR): + install_backend() + + impl.WEBUI_HOST = backend_config.get("host", "localhost") + impl.WEBUI_PORT = backend_config.get("port", "7860") + + env = dict(os.environ) + env.update(get_env()) + + def target(): + global backend_process + + cmd = "webui.bat" if platform.system() == "Windows" else "webui.sh" + print("starting", cmd, WEBUI_DIR) + backend_process = subprocess.Popen([cmd], shell=True, cwd=WEBUI_DIR, env=env) + + backend_thread = threading.Thread(target=target) + backend_thread.start() + + +def stop_backend(): + global backend_process + + if backend_process: + kill(backend_process.pid) + + backend_process = None + + +def uninstall_backend(): + pass + + +def create_context(): + context = local() + + # temp hack, throws an attribute not found error otherwise + context.device = "cuda:0" + context.half_precision = True + context.vram_usage_level = None + + context.models = {} + context.model_paths = {} + context.model_configs = {} + context.device_name = None + context.vram_optimizations = set() + context.vram_usage_level = "balanced" + context.test_diffusers = False + context.enable_codeformer = False + + return context + + +def get_env(): + dir = os.path.abspath(SYSTEM_DIR) + + if not os.path.exists(dir): + raise RuntimeError("The system folder is missing!") + + config = getConfig() + models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models")) + embeddings_dir = os.path.join(models_dir, "embeddings") + + env_entries = { + "PATH": [ + f"{dir}/git/bin", + f"{dir}/python", + f"{dir}/python/Library/bin", + f"{dir}/python/Scripts", + f"{dir}/python/Library/usr/bin", + ], + "PYTHONPATH": [ + f"{dir}/python", + f"{dir}/python/lib/site-packages", + f"{dir}/python/lib/python3.10/site-packages", + ], + "PYTHONHOME": [], + "PY_LIBS": [f"{dir}/python/Scripts/Lib", f"{dir}/python/Scripts/Lib/site-packages"], + "PY_PIP": [f"{dir}/python/Scripts"], + "PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"], + "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], + "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], + "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'], + "SKIP_VENV": ["1"], + "SD_WEBUI_RESTARTING": ["1"], + "PYTHON": [f"{dir}/python/python"], + "GIT": [f"{dir}/git/bin/git"], + } + + if platform.system() == "Windows": + env_entries["PYTHONNOUSERSITE"] = ["1"] + else: + env_entries["PYTHONNOUSERSITE"] = ["y"] + + env = {} + for key, paths in env_entries.items(): + paths = [p.replace("/", os.path.sep) for p in paths] + paths = os.pathsep.join(paths) + + env[key] = paths + + return env + + +# https://stackoverflow.com/a/25134985 +def kill(proc_pid): + process = psutil.Process(proc_pid) + for proc in process.children(recursive=True): + proc.kill() + process.kill() diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py new file mode 100644 index 00000000..c90970ec --- /dev/null +++ b/ui/easydiffusion/backends/webui/impl.py @@ -0,0 +1,639 @@ +import os +import requests +from requests.exceptions import ConnectTimeout +from typing import Union, List +from threading import local as Context +from threading import Thread +import uuid +import time +from copy import deepcopy + +from sdkit.utils import base64_str_to_img, img_to_base64_str + +WEBUI_HOST = "localhost" +WEBUI_PORT = "7860" + +DEFAULT_WEBUI_OPTIONS = {"show_progress_every_n_steps": 3, "show_progress_grid": True, "live_previews_enable": False} + + +webui_opts: dict = None + + +curr_models = { + "stable-diffusion": None, + "vae": None, +} + + +def set_options(context, **kwargs): + changed_opts = {} + + opts_mapping = { + "stream_image_progress": ("live_previews_enable", bool), + "stream_image_progress_interval": ("show_progress_every_n_steps", int), + "clip_skip": ("CLIP_stop_at_last_layers", int), + "clip_skip_sdxl": ("sdxl_clip_l_skip", bool), + "output_format": ("samples_format", str), + } + + for ed_key, webui_key in opts_mapping.items(): + webui_key, webui_type = webui_key + + if ed_key in kwargs and (webui_opts is None or webui_opts.get(webui_key, False) != webui_type(kwargs[ed_key])): + changed_opts[webui_key] = webui_type(kwargs[ed_key]) + + if changed_opts: + changed_opts["sd_model_checkpoint"] = curr_models["stable-diffusion"] + + print(f"Got options: {kwargs}. Sending options: {changed_opts}") + + try: + res = webui_post("/sdapi/v1/options", json=changed_opts) + if res.status_code != 200: + raise Exception(res.text) + + webui_opts.update(changed_opts) + except Exception as e: + print(f"Error setting options: {e}") + + +def ping(timeout=1): + "timeout (in seconds)" + + global webui_opts + + try: + webui_get("/internal/ping", timeout=timeout) + + if webui_opts is None: + try: + res = webui_post("/sdapi/v1/options", json=DEFAULT_WEBUI_OPTIONS) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + print(f"Error setting options: {e}") + + try: + res = webui_get("/sdapi/v1/options") + if res.status_code != 200: + raise Exception(res.text) + + webui_opts = res.json() + except Exception as e: + print(f"Error setting options: {e}") + + return True + except ConnectTimeout as e: + raise TimeoutError(e) + + +def load_model(context, model_type, **kwargs): + model_path = context.model_paths[model_type] + + if webui_opts is None: + print("Server not ready, can't set the model") + return + + if model_type == "stable-diffusion": + model_name = os.path.basename(model_path) + model_name = os.path.splitext(model_name)[0] + print(f"setting sd model: {model_name}") + if curr_models[model_type] != model_name: + try: + res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name}) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + raise RuntimeError( + f"The engine failed to set the required options. Please check the logs in the command line window for more details." + ) + + curr_models[model_type] = model_name + elif model_type == "vae": + if curr_models[model_type] != model_path: + vae_model = [model_path] if model_path else [] + + opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model} + print("setting opts 2", opts) + + try: + res = webui_post("/sdapi/v1/options", json=opts) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + raise RuntimeError( + f"The engine failed to set the required options. Please check the logs in the command line window for more details." + ) + + curr_models[model_type] = model_path + + +def unload_model(context, model_type, **kwargs): + pass + + +def generate_images( + context: Context, + prompt: str = "", + negative_prompt: str = "", + seed: int = 42, + width: int = 512, + height: int = 512, + num_outputs: int = 1, + num_inference_steps: int = 25, + guidance_scale: float = 7.5, + init_image=None, + init_image_mask=None, + control_image=None, + control_alpha=1.0, + controlnet_filter=None, + prompt_strength: float = 0.8, + preserve_init_image_color_profile=False, + strict_mask_border=False, + sampler_name: str = "euler_a", + hypernetwork_strength: float = 0, + tiling=None, + lora_alpha: Union[float, List[float]] = 0, + sampler_params={}, + callback=None, + output_type="pil", +): + + task_id = str(uuid.uuid4()) + + sampler_name = convert_ED_sampler_names(sampler_name) + controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter) + + cmd = { + "force_task_id": task_id, + "prompt": prompt, + "negative_prompt": negative_prompt, + "sampler_name": sampler_name, + "scheduler": "simple", + "steps": num_inference_steps, + "seed": seed, + "cfg_scale": guidance_scale, + "batch_size": num_outputs, + "width": width, + "height": height, + } + + if init_image: + cmd["init_images"] = [init_image] + cmd["denoising_strength"] = prompt_strength + if init_image_mask: + cmd["mask"] = init_image_mask + cmd["include_init_images"] = True + cmd["inpainting_fill"] = 1 + cmd["initial_noise_multiplier"] = 1 + cmd["inpaint_full_res"] = 1 + + if context.model_paths.get("lora"): + lora_model = context.model_paths["lora"] + lora_model = lora_model if isinstance(lora_model, list) else [lora_model] + lora_alpha = lora_alpha if isinstance(lora_alpha, list) else [lora_alpha] + + for lora, alpha in zip(lora_model, lora_alpha): + lora = os.path.basename(lora) + lora = os.path.splitext(lora)[0] + cmd["prompt"] += f" " + + if controlnet_filter and control_image and context.model_paths.get("controlnet"): + controlnet_model = context.model_paths["controlnet"] + + model_hash = auto1111_hash(controlnet_model) + controlnet_model = os.path.basename(controlnet_model) + controlnet_model = os.path.splitext(controlnet_model)[0] + print(f"setting controlnet model: {controlnet_model}") + controlnet_model = f"{controlnet_model} [{model_hash}]" + + cmd["alwayson_scripts"] = { + "controlnet": { + "args": [ + { + "image": control_image, + "weight": control_alpha, + "module": controlnet_filter, + "model": controlnet_model, + "resize_mode": "Crop and Resize", + "threshold_a": 50, + "threshold_b": 130, + } + ] + } + } + + operation_to_apply = "img2img" if init_image else "txt2img" + + stream_image_progress = webui_opts.get("live_previews_enable", False) + + progress_thread = Thread( + target=image_progress_thread, args=(task_id, callback, stream_image_progress, num_outputs, num_inference_steps) + ) + progress_thread.start() + + print(f"task id: {task_id}") + print_request(operation_to_apply, cmd) + + res = webui_post(f"/sdapi/v1/{operation_to_apply}", json=cmd) + if res.status_code == 200: + res = res.json() + else: + raise Exception( + "The engine failed while generating this image. Please check the logs in the command-line window for more details." + ) + + import json + + print(json.loads(res["info"])["infotexts"]) + + images = res["images"] + if output_type == "pil": + images = [base64_str_to_img(img) for img in images] + elif output_type == "base64": + images = [base64_buffer_to_base64_img(img) for img in images] + + return images + + +def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"): + """ + * context: Context + * images: str or PIL.Image or list of str/PIL.Image - image to filter. if a string is passed, it needs to be a base64-encoded image + * filters: filter_type (string) or list of strings + * filter_params: dict + + returns: [PIL.Image] - list of filtered images + """ + images = images if isinstance(images, list) else [images] + filters = filters if isinstance(filters, list) else [filters] + + if "nsfw_checker" in filters: + filters.remove("nsfw_checker") # handled by ED directly + + args = {} + controlnet_filters = [] + + print(filter_params) + + for filter_name in filters: + params = filter_params.get(filter_name, {}) + + if filter_name == "gfpgan": + args["gfpgan_visibility"] = 1 + + if filter_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + args["upscaler_1"] = params.get("upscaler", "RealESRGAN_x4plus") + args["upscaling_resize"] = params.get("scale", 4) + + if args["upscaler_1"] == "RealESRGAN_x4plus": + args["upscaler_1"] = "R-ESRGAN 4x+" + elif args["upscaler_1"] == "RealESRGAN_x4plus_anime_6B": + args["upscaler_1"] = "R-ESRGAN 4x+ Anime6B" + + if filter_name == "codeformer": + args["codeformer_visibility"] = 1 + args["codeformer_weight"] = params.get("codeformer_fidelity", 0.5) + + if filter_name.startswith("controlnet_"): + filter_name = convert_ED_controlnet_filter_name(filter_name) + controlnet_filters.append(filter_name) + + print(f"filtering {len(images)} images with {args}. {controlnet_filters=}") + + if len(filters) > len(controlnet_filters): + filtered_images = extra_batch_images(images, input_type=input_type, **args) + else: + filtered_images = images + + for filter_name in controlnet_filters: + filtered_images = controlnet_filter(filtered_images, module=filter_name, input_type=input_type) + + return filtered_images + + +def get_url(): + return f"//{WEBUI_HOST}:{WEBUI_PORT}/?__theme=dark" + + +def stop_rendering(context): + try: + res = webui_post("/sdapi/v1/interrupt") + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + print(f"Error interrupting webui: {e}") + + +def refresh_models(): + def make_refresh_call(type): + try: + webui_post(f"/sdapi/v1/refresh-{type}") + except: + pass + + try: + for type in ("checkpoints", "vae"): + t = Thread(target=make_refresh_call, args=(type,)) + t.start() + except Exception as e: + print(f"Error refreshing models: {e}") + + +def list_controlnet_filters(): + return [ + "openpose", + "openpose_face", + "openpose_faceonly", + "openpose_hand", + "openpose_full", + "animal_openpose", + "densepose_parula (black bg & blue torso)", + "densepose (pruple bg & purple torso)", + "dw_openpose_full", + "mediapipe_face", + "instant_id_face_keypoints", + "InsightFace+CLIP-H (IPAdapter)", + "InsightFace (InstantID)", + "canny", + "mlsd", + "scribble_hed", + "scribble_hedsafe", + "scribble_pidinet", + "scribble_pidsafe", + "scribble_xdog", + "softedge_hed", + "softedge_hedsafe", + "softedge_pidinet", + "softedge_pidsafe", + "softedge_teed", + "normal_bae", + "depth_midas", + "normal_midas", + "depth_zoe", + "depth_leres", + "depth_leres++", + "depth_anything_v2", + "depth_anything", + "depth_hand_refiner", + "depth_marigold", + "lineart_coarse", + "lineart_realistic", + "lineart_anime", + "lineart_standard (from white bg & black line)", + "lineart_anime_denoise", + "reference_adain", + "reference_only", + "reference_adain+attn", + "tile_colorfix", + "tile_resample", + "tile_colorfix+sharp", + "CLIP-ViT-H (IPAdapter)", + "CLIP-G (Revision)", + "CLIP-G (Revision ignore prompt)", + "CLIP-ViT-bigG (IPAdapter)", + "InsightFace+CLIP-H (IPAdapter)", + "inpaint_only", + "inpaint_only+lama", + "inpaint_global_harmonious", + "seg_ufade20k", + "seg_ofade20k", + "seg_anime_face", + "seg_ofcoco", + "shuffle", + "segment", + "invert (from white bg & black line)", + "threshold", + "t2ia_sketch_pidi", + "t2ia_color_grid", + "recolor_intensity", + "recolor_luminance", + "blur_gaussian", + ] + + +def controlnet_filter(images, module="none", processor_res=512, threshold_a=64, threshold_b=64, input_type="pil"): + if input_type == "pil": + images = [img_to_base64_str(x) for x in images] + + payload = { + "controlnet_module": module, + "controlnet_input_images": images, + "controlnet_processor_res": processor_res, + "controlnet_threshold_a": threshold_a, + "controlnet_threshold_b": threshold_b, + } + res = webui_post("/controlnet/detect", json=payload) + res = res.json() + filtered_images = res["images"] + + if input_type == "pil": + filtered_images = [base64_str_to_img(img) for img in filtered_images] + elif input_type == "base64": + filtered_images = [base64_buffer_to_base64_img(img) for img in filtered_images] + + return filtered_images + + +def image_progress_thread(task_id, callback, stream_image_progress, total_images, total_steps): + from PIL import Image + + last_preview_id = -1 + + EMPTY_IMAGE = Image.new("RGB", (1, 1)) + + while True: + res = webui_post( + f"/internal/progress", + json={"id_task": task_id, "live_preview": stream_image_progress, "id_live_preview": last_preview_id}, + ) + if res.status_code == 200: + res = res.json() + + last_preview_id = res["id_live_preview"] + + if res["progress"] is not None: + step_num = int(res["progress"] * total_steps) + + if res["live_preview"] is not None: + img = res["live_preview"] + img = base64_str_to_img(img) + images = [EMPTY_IMAGE] * total_images + images[0] = img + else: + images = None + + callback(images, step_num) + + if res["completed"] == True: + print("Complete!") + break + + time.sleep(0.5) + + +def webui_get(uri, *args, **kwargs): + url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}" + return requests.get(url, *args, **kwargs) + + +def webui_post(uri, *args, **kwargs): + url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}" + return requests.post(url, *args, **kwargs) + + +def print_request(operation_to_apply, args): + args = deepcopy(args) + if "init_images" in args: + args["init_images"] = ["img" for _ in args["init_images"]] + if "mask" in args: + args["mask"] = "mask_img" + + controlnet_args = args.get("alwayson_scripts", {}).get("controlnet", {}).get("args", []) + if controlnet_args: + controlnet_args[0]["image"] = "control_image" + + print(f"operation: {operation_to_apply}, args: {args}") + + +def auto1111_hash(file_path): + import hashlib + + with open(file_path, "rb") as f: + f.seek(0x100000) + b = f.read(0x10000) + return hashlib.sha256(b).hexdigest()[:8] + + +def extra_batch_images( + images, # list of PIL images + name_list=None, # list of image names + resize_mode=0, + show_extras_results=True, + gfpgan_visibility=0, + codeformer_visibility=0, + codeformer_weight=0, + upscaling_resize=2, + upscaling_resize_w=512, + upscaling_resize_h=512, + upscaling_crop=True, + upscaler_1="None", + upscaler_2="None", + extras_upscaler_2_visibility=0, + upscale_first=False, + use_async=False, + input_type="pil", +): + if name_list is not None: + if len(name_list) != len(images): + raise RuntimeError("len(images) != len(name_list)") + else: + name_list = [f"image{i + 1:05}" for i in range(len(images))] + + if input_type == "pil": + images = [img_to_base64_str(x) for x in images] + + image_list = [] + for name, image in zip(name_list, images): + image_list.append({"data": image, "name": name}) + + payload = { + "resize_mode": resize_mode, + "show_extras_results": show_extras_results, + "gfpgan_visibility": gfpgan_visibility, + "codeformer_visibility": codeformer_visibility, + "codeformer_weight": codeformer_weight, + "upscaling_resize": upscaling_resize, + "upscaling_resize_w": upscaling_resize_w, + "upscaling_resize_h": upscaling_resize_h, + "upscaling_crop": upscaling_crop, + "upscaler_1": upscaler_1, + "upscaler_2": upscaler_2, + "extras_upscaler_2_visibility": extras_upscaler_2_visibility, + "upscale_first": upscale_first, + "imageList": image_list, + } + + res = webui_post("/sdapi/v1/extra-batch-images", json=payload) + if res.status_code == 200: + res = res.json() + else: + raise Exception( + "The engine failed while filtering this image. Please check the logs in the command-line window for more details." + ) + + images = res["images"] + + if input_type == "pil": + images = [base64_str_to_img(img) for img in images] + elif input_type == "base64": + images = [base64_buffer_to_base64_img(img) for img in images] + + return images + + +def base64_buffer_to_base64_img(img): + output_format = webui_opts.get("samples_format", "jpeg") + mime_type = f"image/{output_format.lower()}" + return f"data:{mime_type};base64," + img + + +def convert_ED_sampler_names(sampler_name): + name_mapping = { + "dpmpp_2m": "DPM++ 2M", + "dpmpp_sde": "DPM++ SDE", + "dpmpp_2m_sde": "DPM++ 2M SDE", + "dpmpp_2m_sde_heun": "DPM++ 2M SDE Heun", + "dpmpp_2s_a": "DPM++ 2S a", + "dpmpp_3m_sde": "DPM++ 3M SDE", + "euler_a": "Euler a", + "euler": "Euler", + "lms": "LMS", + "heun": "Heun", + "dpm2": "DPM2", + "dpm2_a": "DPM2 a", + "dpm_fast": "DPM fast", + "dpm_adaptive": "DPM adaptive", + "restart": "Restart", + "heun_pp2": "HeunPP2", + "ipndm": "IPNDM", + "ipndm_v": "IPNDM_V", + "deis": "DEIS", + "ddim": "DDIM", + "ddim_cfgpp": "DDIM CFG++", + "plms": "PLMS", + "unipc": "UniPC", + "lcm": "LCM", + "ddpm": "DDPM", + "forge_flux_realistic": "[Forge] Flux Realistic", + "forge_flux_realistic_slow": "[Forge] Flux Realistic (Slow)", + # deprecated samplers in 3.5 + "dpm_solver_stability": None, + "unipc_snr": None, + "unipc_tu": None, + "unipc_snr_2": None, + "unipc_tu_2": None, + "unipc_tq": None, + } + return name_mapping.get(sampler_name) + + +def convert_ED_controlnet_filter_name(filter): + if filter is None: + return None + + def cn(n): + if n.startswith("controlnet_"): + return n[len("controlnet_") :] + return n + + mapping = { + "controlnet_scribble_hedsafe": None, + "controlnet_scribble_pidsafe": None, + "controlnet_softedge_pidsafe": "controlnet_softedge_pidisafe", + "controlnet_normal_bae": "controlnet_normalbae", + "controlnet_segment": None, + } + if isinstance(filter, list): + return [cn(mapping.get(f, f)) for f in filter] + return cn(mapping.get(filter, filter)) diff --git a/ui/easydiffusion/easydb/schemas.py b/ui/easydiffusion/easydb/schemas.py index 68bc04e2..c6b7337f 100644 --- a/ui/easydiffusion/easydb/schemas.py +++ b/ui/easydiffusion/easydb/schemas.py @@ -33,4 +33,3 @@ class Bucket(BucketBase): class Config: orm_mode = True - diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 26f05242..d821db41 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -8,7 +8,7 @@ from easydiffusion import app from easydiffusion.types import ModelsData from easydiffusion.utils import log from sdkit import Context -from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db +from sdkit.models import scan_model, download_model, get_model_info_from_db from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.utils import hash_file_quick from sdkit.models.model_loader.embeddings import get_embedding_token @@ -25,15 +25,15 @@ KNOWN_MODEL_TYPES = [ "controlnet", ] MODEL_EXTENSIONS = { - "stable-diffusion": [".ckpt", ".safetensors"], - "vae": [".vae.pt", ".ckpt", ".safetensors"], - "hypernetwork": [".pt", ".safetensors"], + "stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"], + "vae": [".vae.pt", ".ckpt", ".safetensors", ".sft"], + "hypernetwork": [".pt", ".safetensors", ".sft"], "gfpgan": [".pth"], "realesrgan": [".pth"], - "lora": [".ckpt", ".safetensors", ".pt"], + "lora": [".ckpt", ".safetensors", ".sft", ".pt"], "codeformer": [".pth"], - "embeddings": [".pt", ".bin", ".safetensors"], - "controlnet": [".pth", ".safetensors"], + "embeddings": [".pt", ".bin", ".safetensors", ".sft"], + "controlnet": [".pth", ".safetensors", ".sft"], } DEFAULT_MODELS = { "stable-diffusion": [ @@ -63,6 +63,7 @@ def init(): def load_default_models(context: Context): from easydiffusion import runtime + from easydiffusion.backend_manager import backend runtime.set_vram_optimizations(context) @@ -70,7 +71,7 @@ def load_default_models(context: Context): for model_type in MODELS_TO_LOAD_ON_START: context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False) try: - load_model( + backend.load_model( context, model_type, scan_model=context.model_paths[model_type] != None @@ -92,8 +93,10 @@ def load_default_models(context: Context): def unload_all(context: Context): + from easydiffusion.backend_manager import backend + for model_type in KNOWN_MODEL_TYPES: - unload_model(context, model_type) + backend.unload_model(context, model_type) if model_type in context.model_load_errors: del context.model_load_errors[model_type] @@ -154,6 +157,8 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []): + from easydiffusion.backend_manager import backend + models_to_reload = { model_type: path for model_type, path in models_data.model_paths.items() @@ -175,7 +180,7 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req - action_fn = unload_model if context.model_paths[model_type] is None else load_model + action_fn = backend.unload_model if context.model_paths[model_type] is None else backend.load_model extra_params = models_data.model_params.get(model_type, {}) try: action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already @@ -183,14 +188,23 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models del context.model_load_errors[model_type] except Exception as e: log.exception(e) - if action_fn == load_model: + if action_fn == backend.load_model: context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks def resolve_model_paths(models_data: ModelsData): model_paths = models_data.model_paths + skip_models = cn_filters + [ + "latent_upscaler", + "nsfw_checker", + "esrgan_4x", + "lanczos", + "nearest", + "scunet", + "swinir", + ] + for model_type in model_paths: - skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"] if model_type in skip_models: # doesn't use model paths continue if model_type == "codeformer" and model_paths[model_type]: @@ -320,6 +334,10 @@ def is_malicious_model(file_path): def getModels(scan_for_malicious: bool = True): + from easydiffusion.backend_manager import backend + + backend.refresh_models() + models = { "options": { "stable-diffusion": [], @@ -329,19 +347,19 @@ def getModels(scan_for_malicious: bool = True): "codeformer": [{"codeformer": "CodeFormer"}], "embeddings": [], "controlnet": [ - {"control_v11p_sd15_canny": "Canny (*)"}, - {"control_v11p_sd15_openpose": "OpenPose (*)"}, - {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, - {"control_v11f1p_sd15_depth": "Depth (*)"}, - {"control_v11p_sd15_scribble": "Scribble"}, - {"control_v11p_sd15_softedge": "Soft Edge"}, - {"control_v11p_sd15_inpaint": "Inpaint"}, - {"control_v11p_sd15_lineart": "Line Art"}, - {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, - {"control_v11p_sd15_mlsd": "Straight Lines"}, - {"control_v11p_sd15_seg": "Segment"}, - {"control_v11e_sd15_shuffle": "Shuffle"}, - {"control_v11f1e_sd15_tile": "Tile"}, + # {"control_v11p_sd15_canny": "Canny (*)"}, + # {"control_v11p_sd15_openpose": "OpenPose (*)"}, + # {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, + # {"control_v11f1p_sd15_depth": "Depth (*)"}, + # {"control_v11p_sd15_scribble": "Scribble"}, + # {"control_v11p_sd15_softedge": "Soft Edge"}, + # {"control_v11p_sd15_inpaint": "Inpaint"}, + # {"control_v11p_sd15_lineart": "Line Art"}, + # {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, + # {"control_v11p_sd15_mlsd": "Straight Lines"}, + # {"control_v11p_sd15_seg": "Segment"}, + # {"control_v11e_sd15_shuffle": "Shuffle"}, + # {"control_v11f1e_sd15_tile": "Tile"}, ], }, } @@ -378,6 +396,8 @@ def getModels(scan_for_malicious: bool = True): model_id = entry.name[: -len(matching_suffix)] if callable(nameFilter): model_id = nameFilter(model_id) + if model_id is None: + continue model_exists = False for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models @@ -416,7 +436,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="stable-diffusion") listModels(model_type="vae") listModels(model_type="hypernetwork") - listModels(model_type="gfpgan") + listModels(model_type="gfpgan", nameFilter=lambda x: (x if "gfpgan" in x.lower() else None)) listModels(model_type="lora") listModels(model_type="embeddings", nameFilter=get_embedding_token) listModels(model_type="controlnet") diff --git a/ui/easydiffusion/runtime.py b/ui/easydiffusion/runtime.py index 78d90f60..accced00 100644 --- a/ui/easydiffusion/runtime.py +++ b/ui/easydiffusion/runtime.py @@ -1,4 +1,5 @@ """ +(OUTDATED DOC) A runtime that runs on a specific device (in a thread). It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context. @@ -6,42 +7,35 @@ It can run various tasks like image generation, image filtering, model merge etc This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function. """ -from easydiffusion import device_manager -from easydiffusion.utils import log -from sdkit import Context -from sdkit.utils import get_device_usage - -context = Context() # thread-local -""" -runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc -""" +context = None def init(device): """ Initializes the fields that will be bound to this runtime's context, and sets the current torch device """ + + global context + + from easydiffusion import device_manager + from easydiffusion.backend_manager import backend + from easydiffusion.app import getConfig + + context = backend.create_context() + context.stop_processing = False context.temp_images = {} context.partial_x_samples = None context.model_load_errors = {} context.enable_codeformer = True - from easydiffusion import app - - app_config = app.getConfig() - context.test_diffusers = app_config.get("use_v3_engine", True) - - log.info("Device usage during initialization:") - get_device_usage(device, log_info=True, process_usage_only=False) - device_manager.device_init(context, device) -def set_vram_optimizations(context: Context): - from easydiffusion import app +def set_vram_optimizations(context): + from easydiffusion.app import getConfig - config = app.getConfig() + config = getConfig() vram_usage_level = config.get("vram_usage_level", "balanced") if vram_usage_level != context.vram_usage_level: diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index a251ede6..ca7dc98e 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -2,6 +2,7 @@ Notes: async endpoints always run on the main thread. Without they run on the thread pool. """ + import datetime import mimetypes import os @@ -20,6 +21,7 @@ from easydiffusion.types import ( OutputFormatData, SaveToDiskData, convert_legacy_render_req_to_new, + convert_legacy_controlnet_filter_name, ) from easydiffusion.utils import log from fastapi import FastAPI, HTTPException @@ -67,6 +69,7 @@ class SetAppConfigRequest(BaseModel, extra=Extra.allow): listen_to_network: bool = None listen_port: int = None use_v3_engine: bool = True + backend: str = "ed_diffusers" models_dir: str = None @@ -155,6 +158,12 @@ def init(): def shutdown_event(): # Signal render thread to close on shutdown task_manager.current_state_error = SystemExit("Application shutting down.") + @server_api.on_event("startup") + def start_event(): + from easydiffusion.app import open_browser + + open_browser() + # API implementations def set_app_config_internal(req: SetAppConfigRequest): @@ -176,7 +185,8 @@ def set_app_config_internal(req: SetAppConfigRequest): config["net"] = {} config["net"]["listen_port"] = int(req.listen_port) - config["use_v3_engine"] = req.use_v3_engine + config["use_v3_engine"] = req.backend == "ed_diffusers" + config["backend"] = req.backend config["models_dir"] = req.models_dir for property, property_value in req.dict().items(): @@ -216,6 +226,8 @@ def read_web_data_internal(key: str = None, **kwargs): return JSONResponse(config, headers=NOCACHE_HEADERS) elif key == "system_info": + from easydiffusion.backend_manager import backend + config = app.getConfig() output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) @@ -226,6 +238,7 @@ def read_web_data_internal(key: str = None, **kwargs): "default_output_dir": output_dir, "enforce_output_dir": ("force_save_path" in config), "enforce_output_metadata": ("force_save_metadata" in config), + "backend_url": backend.get_url(), } system_info["devices"]["config"] = config.get("render_devices", "auto") return JSONResponse(system_info, headers=NOCACHE_HEADERS) @@ -309,6 +322,15 @@ def filter_internal(req: dict): output_format: OutputFormatData = OutputFormatData.parse_obj(req) save_data: SaveToDiskData = SaveToDiskData.parse_obj(req) + filter_req.filter = convert_legacy_controlnet_filter_name(filter_req.filter) + + for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if models_data.model_paths.get(model_name): + if model_name not in filter_req.filter_params: + filter_req.filter_params[model_name] = {} + + filter_req.filter_params[model_name]["upscaler"] = models_data.model_paths[model_name] + # enqueue the task task = FilterTask(filter_req, task_data, models_data, output_format, save_data) return enqueue_task(task) diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 699b4494..5ad6420d 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -4,6 +4,7 @@ Notes: Use weak_thread_data to store all other data using weak keys. This will allow for garbage collection after the thread dies. """ + import json import traceback @@ -19,7 +20,6 @@ import torch from easydiffusion import device_manager from easydiffusion.tasks import Task from easydiffusion.utils import log -from sdkit.utils import gc THREAD_NAME_PREFIX = "" ERR_LOCK_FAILED = " failed to acquire lock within timeout." @@ -233,6 +233,7 @@ def thread_render(device): global current_state, current_state_error from easydiffusion import model_manager, runtime + from easydiffusion.backend_manager import backend try: runtime.init(device) @@ -244,8 +245,17 @@ def thread_render(device): } current_state = ServerStates.LoadingModel - model_manager.load_default_models(runtime.context) + while True: + try: + if backend.ping(timeout=1): + break + + time.sleep(1) + except TimeoutError: + time.sleep(1) + + model_manager.load_default_models(runtime.context) current_state = ServerStates.Online except Exception as e: log.error(traceback.format_exc()) @@ -291,7 +301,6 @@ def thread_render(device): task.buffer_queue.put(json.dumps(task.response)) log.error(traceback.format_exc()) finally: - gc(runtime.context) task.lock.release() keep_task_alive(task) diff --git a/ui/easydiffusion/tasks/filter_images.py b/ui/easydiffusion/tasks/filter_images.py index 7d3d1326..67572a8a 100644 --- a/ui/easydiffusion/tasks/filter_images.py +++ b/ui/easydiffusion/tasks/filter_images.py @@ -5,9 +5,7 @@ import time from numpy import base_repr -from sdkit.filter import apply_filters -from sdkit.models import load_model -from sdkit.utils import img_to_base64_str, get_image, log, save_images +from sdkit.utils import img_to_base64_str, log, save_images, base64_str_to_img from easydiffusion import model_manager, runtime from easydiffusion.types import ( @@ -19,6 +17,7 @@ from easydiffusion.types import ( TaskData, GenerateImageRequest, ) +from easydiffusion.utils import filter_nsfw from easydiffusion.utils.save_utils import format_folder_name from .task import Task @@ -47,7 +46,9 @@ class FilterTask(Task): # convert to multi-filter format, if necessary if isinstance(req.filter, str): - req.filter_params = {req.filter: req.filter_params} + if req.filter not in req.filter_params: + req.filter_params = {req.filter: req.filter_params} + req.filter = [req.filter] if not isinstance(req.image, list): @@ -57,6 +58,7 @@ class FilterTask(Task): "Runs the image filtering task on the assigned thread" from easydiffusion import app + from easydiffusion.backend_manager import backend context = runtime.context @@ -66,15 +68,24 @@ class FilterTask(Task): print_task_info(self.request, self.models_data, self.output_format, self.save_data) - if isinstance(self.request.image, list): - images = [get_image(img) for img in self.request.image] - else: - images = get_image(self.request.image) - - images = filter_images(context, images, self.request.filter, self.request.filter_params) + has_nsfw_filter = "nsfw_filter" in self.request.filter output_format = self.output_format + backend.set_options( + context, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + ) + + images = backend.filter_images( + context, self.request.image, self.request.filter, self.request.filter_params, input_type="base64" + ) + + if has_nsfw_filter: + images = filter_nsfw(images) + if self.save_data.save_to_disk_path is not None: app_config = app.getConfig() folder_format = app_config.get("folder_format", "$id") @@ -85,8 +96,9 @@ class FilterTask(Task): save_dir_path = os.path.join( self.save_data.save_to_disk_path, format_folder_name(folder_format, dummy_req, self.task_data) ) + images_pil = [base64_str_to_img(img) for img in images] save_images( - images, + images_pil, save_dir_path, file_name=img_id, output_format=output_format.output_format, @@ -94,13 +106,6 @@ class FilterTask(Task): output_lossless=output_format.output_lossless, ) - images = [ - img_to_base64_str( - img, output_format.output_format, output_format.output_quality, output_format.output_lossless - ) - for img in images - ] - res = FilterImageResponse(self.request, self.models_data, images=images) res = res.json() self.buffer_queue.put(json.dumps(res)) @@ -110,46 +115,6 @@ class FilterTask(Task): self.response = res -def filter_images(context, images, filters, filter_params={}): - filters = filters if isinstance(filters, list) else [filters] - - for filter_name in filters: - params = filter_params.get(filter_name, {}) - - previous_state = before_filter(context, filter_name, params) - - try: - images = apply_filters(context, filter_name, images, **params) - finally: - after_filter(context, filter_name, params, previous_state) - - return images - - -def before_filter(context, filter_name, filter_params): - if filter_name == "codeformer": - from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use - - default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] - prev_realesrgan_path = None - - upscale_faces = filter_params.get("upscale_faces", False) - if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: - prev_realesrgan_path = context.model_paths.get("realesrgan") - context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") - load_model(context, "realesrgan") - - return prev_realesrgan_path - - -def after_filter(context, filter_name, filter_params, previous_state): - if filter_name == "codeformer": - prev_realesrgan_path = previous_state - if prev_realesrgan_path: - context.model_paths["realesrgan"] = prev_realesrgan_path - load_model(context, "realesrgan") - - def print_task_info( req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData ): diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index 7d8fbc5e..2a46e6e6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -2,26 +2,23 @@ import json import pprint import queue import time +from PIL import Image from easydiffusion import model_manager, runtime from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData, SaveToDiskData from easydiffusion.types import Image as ResponseImage -from easydiffusion.types import GenerateImageResponse, RenderTaskData, UserInitiatedStop -from easydiffusion.utils import get_printable_request, log, save_images_to_disk -from sdkit.generate import generate_images +from easydiffusion.types import GenerateImageResponse, RenderTaskData +from easydiffusion.utils import get_printable_request, log, save_images_to_disk, filter_nsfw from sdkit.utils import ( - diffusers_latent_samples_to_images, - gc, img_to_base64_str, + base64_str_to_img, img_to_buffer, - latent_samples_to_images, resize_img, get_image, log, ) from .task import Task -from .filter_images import filter_images class RenderTask(Task): @@ -51,15 +48,13 @@ class RenderTask(Task): "Runs the image generation task on the assigned thread" from easydiffusion import task_manager, app + from easydiffusion.backend_manager import backend context = runtime.context config = app.getConfig() if config.get("block_nsfw", False): # override if set on the server self.task_data.block_nsfw = True - if "nsfw_checker" not in self.task_data.filters: - self.task_data.filters.append("nsfw_checker") - self.models_data.model_paths["nsfw_checker"] = "nsfw_checker" def step_callback(): task_manager.keep_task_alive(self) @@ -68,7 +63,7 @@ class RenderTask(Task): if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance( self.error, StopAsyncIteration ): - context.stop_processing = True + backend.stop_rendering(context) if isinstance(task_manager.current_state_error, StopAsyncIteration): self.error = task_manager.current_state_error task_manager.current_state_error = None @@ -78,11 +73,7 @@ class RenderTask(Task): model_manager.resolve_model_paths(self.models_data) models_to_force_reload = [] - if ( - runtime.set_vram_optimizations(context) - or self.has_param_changed(context, "clip_skip") - or self.trt_needs_reload(context) - ): + if runtime.set_vram_optimizations(context) or self.has_param_changed(context, "clip_skip"): models_to_force_reload.append("stable-diffusion") model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload) @@ -99,10 +90,11 @@ class RenderTask(Task): self.buffer_queue, self.temp_images, step_callback, + self, ) def has_param_changed(self, context, param_name): - if not context.test_diffusers: + if not getattr(context, "test_diffusers", False): return False if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: return True @@ -111,29 +103,6 @@ class RenderTask(Task): new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) return model["params"].get(param_name) != new_val - def trt_needs_reload(self, context): - if not context.test_diffusers: - return False - if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: - return True - - model = context.models["stable-diffusion"] - - # curr_convert_to_trt = model["params"].get("convert_to_tensorrt") - new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False) - - pipe = model["default"] - is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr( - pipe.unet, "_allocate_trt_buffers_backup" - ) - if new_convert_to_trt and not is_trt_loaded: - return True - - curr_build_config = model["params"].get("trt_build_config") - new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {}) - - return new_convert_to_trt and curr_build_config != new_build_config - def make_images( context, @@ -145,12 +114,21 @@ def make_images( data_queue: queue.Queue, task_temp_images: list, step_callback, + task, ): - context.stop_processing = False print_task_info(req, task_data, models_data, output_format, save_data) images, seeds = make_images_internal( - context, req, task_data, models_data, output_format, save_data, data_queue, task_temp_images, step_callback + context, + req, + task_data, + models_data, + output_format, + save_data, + data_queue, + task_temp_images, + step_callback, + task, ) res = GenerateImageResponse( @@ -170,7 +148,9 @@ def print_task_info( output_format: OutputFormatData, save_data: SaveToDiskData, ): - req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[") + req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace( + "[", "\[" + ) task_str = pprint.pformat(task_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[") @@ -178,7 +158,7 @@ def print_task_info( log.info(f"request: {req_str}") log.info(f"task data: {task_str}") - # log.info(f"models data: {models_data}") + log.info(f"models data: {models_data}") log.info(f"output format: {output_format}") log.info(f"save data: {save_data}") @@ -193,26 +173,41 @@ def make_images_internal( data_queue: queue.Queue, task_temp_images: list, step_callback, + task, ): - images, user_stopped = generate_images_internal( + from easydiffusion.backend_manager import backend + + # prep the nsfw_filter + if task_data.block_nsfw: + filter_nsfw([Image.new("RGB", (1, 1))]) # hack - ensures that the model is available + + images = generate_images_internal( context, req, task_data, models_data, + output_format, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval, ) - - gc(context) + user_stopped = isinstance(task.error, StopAsyncIteration) filters, filter_params = task_data.filters, task_data.filter_params - filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images + if len(filters) > 0 and not user_stopped: + filtered_images = backend.filter_images(context, images, filters, filter_params, input_type="base64") + else: + filtered_images = images + + if task_data.block_nsfw: + filtered_images = filter_nsfw(filtered_images) if save_data.save_to_disk_path is not None: - save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data) + images_pil = [base64_str_to_img(img) for img in images] + filtered_images_pil = [base64_str_to_img(img) for img in filtered_images] + save_images_to_disk(images_pil, filtered_images_pil, req, task_data, models_data, output_format, save_data) seeds = [*range(req.seed, req.seed + len(images))] if task_data.show_only_filtered_image or filtered_images is images: @@ -226,97 +221,43 @@ def generate_images_internal( req: GenerateImageRequest, task_data: RenderTaskData, models_data: ModelsData, + output_format: OutputFormatData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool, stream_image_progress_interval: int, ): - context.temp_images.clear() + from easydiffusion.backend_manager import backend - callback = make_step_callback( + callback = make_step_callback(context, req, task_data, data_queue, task_temp_images, step_callback) + + req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8 + + if req.control_image and task_data.control_filter_to_apply: + req.controlnet_filter = task_data.control_filter_to_apply + + if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: + req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 + + backend.set_options( context, - req, - task_data, - data_queue, - task_temp_images, - step_callback, - stream_image_progress, - stream_image_progress_interval, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + vae_tiling=task_data.enable_vae_tiling, + stream_image_progress=stream_image_progress, + stream_image_progress_interval=stream_image_progress_interval, + clip_skip=2 if task_data.clip_skip else 1, ) - try: - if req.init_image is not None and not context.test_diffusers: - req.sampler_name = "ddim" + images = backend.generate_images(context, callback=callback, output_type="base64", **req.dict()) - req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8 - - if req.control_image and task_data.control_filter_to_apply: - req.control_image = get_image(req.control_image) - req.control_image = resize_img(req.control_image.convert("RGB"), req.width, req.height, clamp_to_8=True) - req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0] - - if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: - req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 - - if context.test_diffusers: - pipe = context.models["stable-diffusion"]["default"] - if hasattr(pipe.unet, "_allocate_trt_buffers_backup"): - setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup) - delattr(pipe.unet, "_allocate_trt_buffers_backup") - - if hasattr(pipe.unet, "_allocate_trt_buffers"): - convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False) - if convert_to_trt: - pipe.unet.forward = pipe.unet._trt_forward - # pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward - log.info(f"Setting unet.forward to TensorRT") - else: - log.info(f"Not using TensorRT for unet.forward") - pipe.unet.forward = pipe.unet._non_trt_forward - # pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward - setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers) - delattr(pipe.unet, "_allocate_trt_buffers") - - if task_data.enable_vae_tiling: - if hasattr(pipe, "enable_vae_tiling"): - pipe.enable_vae_tiling() - else: - if hasattr(pipe, "disable_vae_tiling"): - pipe.disable_vae_tiling() - - images = generate_images(context, callback=callback, **req.dict()) - user_stopped = False - except UserInitiatedStop: - images = [] - user_stopped = True - if context.partial_x_samples is not None: - if context.test_diffusers: - images = diffusers_latent_samples_to_images(context, context.partial_x_samples) - else: - images = latent_samples_to_images(context, context.partial_x_samples) - finally: - if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: - if not context.test_diffusers: - del context.partial_x_samples - context.partial_x_samples = None - - return images, user_stopped + return images def construct_response(images: list, seeds: list, output_format: OutputFormatData): - return [ - ResponseImage( - data=img_to_base64_str( - img, - output_format.output_format, - output_format.output_quality, - output_format.output_lossless, - ), - seed=seed, - ) - for img, seed in zip(images, seeds) - ] + return [ResponseImage(data=img, seed=seed) for img, seed in zip(images, seeds)] def make_step_callback( @@ -326,53 +267,44 @@ def make_step_callback( data_queue: queue.Queue, task_temp_images: list, step_callback, - stream_image_progress: bool, - stream_image_progress_interval: int, ): + from easydiffusion.backend_manager import backend + n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 - def update_temp_img(x_samples, task_temp_images: list): + def update_temp_img(images, task_temp_images: list): partial_images = [] - if context.test_diffusers: - images = diffusers_latent_samples_to_images(context, x_samples) - else: - images = latent_samples_to_images(context, x_samples) + if images is None: + return [] if task_data.block_nsfw: - images = filter_images(context, images, "nsfw_checker") + images = filter_nsfw(images, print_log=False) for i, img in enumerate(images): + img = img.convert("RGB") + img = resize_img(img, req.width, req.height) buf = img_to_buffer(img, output_format="JPEG") - context.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"}) del images return partial_images - def on_image_step(x_samples, i, *args): + def on_image_step(images, i, *args): nonlocal last_callback_time - if context.test_diffusers: - context.partial_x_samples = (x_samples, args[0]) - else: - context.partial_x_samples = x_samples - step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 last_callback_time = time.time() progress = {"step": i, "step_time": step_time, "total_steps": n_steps} - if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: - progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images) + if images is not None: + progress["output"] = update_temp_img(images, task_temp_images) data_queue.put(json.dumps(progress)) step_callback() - if context.stop_processing: - raise UserInitiatedStop("User requested that we stop processing") - return on_image_step diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index f16e9aa3..93bfe08f 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -19,9 +19,10 @@ class GenerateImageRequest(BaseModel): init_image_mask: Any = None control_image: Any = None control_alpha: Union[float, List[float]] = None + controlnet_filter: str = None prompt_strength: float = 0.8 - preserve_init_image_color_profile = False - strict_mask_border = False + preserve_init_image_color_profile: bool = False + strict_mask_border: bool = False sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" hypernetwork_strength: float = 0 @@ -100,7 +101,7 @@ class MergeRequest(BaseModel): model1: str = None ratio: float = None out_path: str = "mix" - use_fp16 = True + use_fp16: bool = True class Image: @@ -213,22 +214,19 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["controlnet"] = old_req.get("use_controlnet_model") model_paths["embeddings"] = old_req.get("use_embeddings_model") - model_paths["gfpgan"] = old_req.get("use_face_correction", "") - model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None + ## ensure that the model name is in the model path + for model_name in ("gfpgan", "codeformer"): + model_paths[model_name] = old_req.get("use_face_correction", "") + model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None - model_paths["codeformer"] = old_req.get("use_face_correction", "") - model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None + for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + model_paths[model_name] = old_req.get("use_upscale", "") + model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None - model_paths["realesrgan"] = old_req.get("use_upscale", "") - model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None - - model_paths["latent_upscaler"] = old_req.get("use_upscale", "") - model_paths["latent_upscaler"] = ( - model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None - ) if "control_filter_to_apply" in old_req: filter_model = old_req["control_filter_to_apply"] model_paths[filter_model] = filter_model + old_req["control_filter_to_apply"] = convert_legacy_controlnet_filter_name(old_req["control_filter_to_apply"]) if old_req.get("block_nsfw"): model_paths["nsfw_checker"] = "nsfw_checker" @@ -244,8 +242,12 @@ def convert_legacy_render_req_to_new(old_req: dict): } # move the filter params - if model_paths["realesrgan"]: - filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))} + for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if model_paths[model_name]: + filter_params[model_name] = { + "upscaler": model_paths[model_name], + "scale": int(old_req.get("upscale_amount", 4)), + } if model_paths["latent_upscaler"]: filter_params["latent_upscaler"] = { "prompt": old_req["prompt"], @@ -264,14 +266,31 @@ def convert_legacy_render_req_to_new(old_req: dict): if old_req.get("block_nsfw"): filters.append("nsfw_checker") - if model_paths["codeformer"]: - filters.append("codeformer") - elif model_paths["gfpgan"]: - filters.append("gfpgan") + for model_name in ("gfpgan", "codeformer"): + if model_paths[model_name]: + filters.append(model_name) + break - if model_paths["realesrgan"]: - filters.append("realesrgan") - elif model_paths["latent_upscaler"]: - filters.append("latent_upscaler") + for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if model_paths[model_name]: + filters.append(model_name) + break return new_req + + +def convert_legacy_controlnet_filter_name(filter): + from easydiffusion.backend_manager import backend + + if filter is None: + return None + + controlnet_filter_names = backend.list_controlnet_filters() + + def apply(f): + return f"controlnet_{f}" if f in controlnet_filter_names else f + + if isinstance(filter, list): + return [apply(f) for f in filter] + + return apply(filter) diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py index a930725b..3be08e79 100644 --- a/ui/easydiffusion/utils/__init__.py +++ b/ui/easydiffusion/utils/__init__.py @@ -7,6 +7,8 @@ from .save_utils import ( save_images_to_disk, get_printable_request, ) +from .nsfw_checker import filter_nsfw + def sha256sum(filename): sha256 = hashlib.sha256() @@ -18,4 +20,3 @@ def sha256sum(filename): sha256.update(data) return sha256.hexdigest() - diff --git a/ui/easydiffusion/utils/nsfw_checker.py b/ui/easydiffusion/utils/nsfw_checker.py new file mode 100644 index 00000000..3790cacc --- /dev/null +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -0,0 +1,80 @@ +# possibly move this to sdkit in the future +import os + +# mirror of https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/model_quantized.onnx +NSFW_MODEL_URL = ( + "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/vit-base-nsfw-detector-quantized.onnx" +) +MODEL_HASH_QUICK = "220123559305b1b07b7a0894c3471e34dccd090d71cdf337dd8012f9e40d6c28" + +nsfw_check_model = None + + +def filter_nsfw(images, blur_radius: float = 75, print_log=True): + global nsfw_check_model + + from easydiffusion.app import MODELS_DIR + from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick + + import onnxruntime as ort + from PIL import ImageFilter + import numpy as np + + if nsfw_check_model is None: + model_dir = os.path.join(MODELS_DIR, "nsfw-checker") + model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") + + os.makedirs(model_dir, exist_ok=True) + + if not os.path.exists(model_path) or hash_file_quick(model_path) != MODEL_HASH_QUICK: + download_file(NSFW_MODEL_URL, model_path) + + nsfw_check_model = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + # Preprocess the input image + def preprocess_image(img): + img = img.convert("RGB") + + # config based on based on https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/preprocessor_config.json + # Resize the image + img = img.resize((384, 384)) + + # Normalize the image + img = np.array(img) / 255.0 # Scale pixel values to [0, 1] + mean = np.array([0.5, 0.5, 0.5]) + std = np.array([0.5, 0.5, 0.5]) + img = (img - mean) / std + + # Transpose to match input shape (batch_size, channels, height, width) + img = np.transpose(img, (2, 0, 1)).astype(np.float32) + + # Add batch dimension + img = np.expand_dims(img, axis=0) + + return img + + # Run inference + input_name = nsfw_check_model.get_inputs()[0].name + output_name = nsfw_check_model.get_outputs()[0].name + + if print_log: + log.info("Running NSFW checker (onnx)") + + results = [] + for img in images: + is_base64 = isinstance(img, str) + + input_img = base64_str_to_img(img) if is_base64 else img + + result = nsfw_check_model.run([output_name], {input_name: preprocess_image(input_img)}) + is_nsfw = [np.argmax(arr) == 1 for arr in result][0] + + if is_nsfw: + output_img = input_img.filter(ImageFilter.GaussianBlur(blur_radius)) + output_img = img_to_base64_str(output_img) if is_base64 else output_img + else: + output_img = img + + results.append(output_img) + + return results diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index e1b4d79a..216ec899 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -247,7 +247,7 @@ def get_printable_request( task_data_metadata.update(save_data.dict()) app_config = app.getConfig() - using_diffusers = app_config.get("use_v3_engine", True) + using_diffusers = app_config.get("backend", "ed_diffusers") in ("ed_diffusers", "webui") # Save the metadata in the order defined in TASK_TEXT_MAPPING metadata = {} diff --git a/ui/index.html b/ui/index.html index a50c995d..2fd11a5c 100644 --- a/ui/index.html +++ b/ui/index.html @@ -35,7 +35,10 @@

Easy Diffusion - v3.0.9 + + v3.5.0 + +

@@ -73,7 +76,7 @@
- +
@@ -83,7 +86,7 @@ Click to learn more about Negative Prompts (optional) - +
@@ -174,14 +177,14 @@ - + Click to learn more about Clip Skip - +
@@ -201,40 +204,92 @@ + + + + + + + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + + + + +
-
- +
@@ -248,27 +303,38 @@ - Click to learn more about samplers + Click to learn more about samplers
- + @@ -357,14 +423,14 @@
- + - +
- + - + @@ -405,7 +471,7 @@
- +
  • @@ -418,7 +484,13 @@ @@ -825,7 +897,8 @@

    This license of this software forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm,
    spread misinformation and target vulnerable groups. For the full list of restrictions please read the license.

    By using this software, you consent to the terms and conditions of the license.

    - + + diff --git a/ui/main.py b/ui/main.py index 0239c829..1c4056a3 100644 --- a/ui/main.py +++ b/ui/main.py @@ -9,6 +9,3 @@ server.init() model_manager.init() app.init_render_threads() bucket_manager.init() - -# start the browser ui -app.open_browser() diff --git a/ui/media/css/auto-save.css b/ui/media/css/auto-save.css index 94790740..021f06d5 100644 --- a/ui/media/css/auto-save.css +++ b/ui/media/css/auto-save.css @@ -79,6 +79,7 @@ } .parameters-table .fa-fire, -.parameters-table .fa-bolt { +.parameters-table .fa-bolt, +.parameters-table .fa-robot { color: #F7630C; } diff --git a/ui/media/css/main.css b/ui/media/css/main.css index 009c13d5..c5e862fe 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -36,6 +36,15 @@ code { transform: translateY(4px); cursor: pointer; } +#engine-logo { + font-size: 8pt; + padding-left: 10pt; + color: var(--small-label-color); +} +#engine-logo a { + text-decoration: none; + /* color: var(--small-label-color); */ +} #prompt { width: 100%; height: 65pt; @@ -541,7 +550,7 @@ div.img-preview img { position: relative; background: var(--background-color4); display: flex; - padding: 12px 0 0; + padding: 6px 0 0; } .tab .icon { padding-right: 4pt; @@ -657,6 +666,10 @@ div.img-preview img { display: block; } +.gated-feature { + display: none; +} + .display-settings { float: right; position: relative; diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 1cbd3288..2b90a6d3 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -981,7 +981,20 @@ function onRedoFilter(req, img, e, tools) { function onUpscaleClick(req, img, e, tools) { let path = upscaleModelField.value let scale = parseInt(upscaleAmountField.value) - let filterName = path.toLowerCase().includes("realesrgan") ? "realesrgan" : "latent_upscaler" + + let filterName = null + const FILTERS = ["realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"] + for (let idx in FILTERS) { + let f = FILTERS[idx] + if (path.toLowerCase().includes(f)) { + filterName = f + break + } + } + + if (!filterName) { + return + } let statusText = "Upscaling by " + scale + "x using " + filterName applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools) } diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 97b7a96d..a597b281 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -249,14 +249,19 @@ var PARAMETERS = [ default: false, }, { - id: "use_v3_engine", - type: ParameterType.checkbox, - label: "Use the new v3 engine (diffusers)", + id: "backend", + type: ParameterType.select, + label: "Engine to use", note: - "Use our new v3 engine, with additional features like LoRA, ControlNet, SDXL, Embeddings, Tiling and lots more! Please press Save, then restart the program after changing this.", - icon: "fa-bolt", - default: true, + "Use our new v3.5 engine (Forge), with additional features like Flux, SD3, Lycoris and lots more! Please press Save, then restart the program after changing this.", + icon: "fa-robot", saveInAppConfig: true, + default: "ed_diffusers", + options: [ + { value: "webui", label: "v3.5 (latest)" }, + { value: "ed_diffusers", label: "v3.0" }, + { value: "ed_classic", label: "v2.0" }, + ], }, { id: "cloudflare", @@ -432,6 +437,7 @@ let useBetaChannelField = document.querySelector("#use_beta_channel") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") let testDiffusers = document.querySelector("#use_v3_engine") +let backendEngine = document.querySelector("#backend") let profileNameField = document.querySelector("#profileName") let modelsDirField = document.querySelector("#models_dir") @@ -454,6 +460,23 @@ async function changeAppConfig(configDelta) { } } +function getDefaultDisplay(element) { + const tag = element.tagName.toLowerCase(); + const defaultDisplays = { + div: 'block', + span: 'inline', + p: 'block', + tr: 'table-row', + table: 'table', + li: 'list-item', + ul: 'block', + ol: 'block', + button: 'inline', + // Add more if needed + }; + return defaultDisplays[tag] || 'block'; // Default to 'block' if not listed +} + async function getAppConfig() { try { let res = await fetch("/get/app_config") @@ -478,14 +501,16 @@ async function getAppConfig() { modelsDirField.value = config.models_dir let testDiffusersEnabled = true - if (config.use_v3_engine === false) { + if (config.backend === "ed_classic") { testDiffusersEnabled = false } testDiffusers.checked = testDiffusersEnabled + backendEngine.value = config.backend document.querySelector("#test_diffusers").checked = testDiffusers.checked // don't break plugins + document.querySelector("#use_v3_engine").checked = testDiffusers.checked // don't break plugins if (config.config_on_startup) { - if (config.config_on_startup?.use_v3_engine) { + if (config.config_on_startup?.backend !== "ed_classic") { document.body.classList.add("diffusers-enabled-on-startup") document.body.classList.remove("diffusers-disabled-on-startup") } else { @@ -494,37 +519,27 @@ async function getAppConfig() { } } - if (!testDiffusersEnabled) { - document.querySelector("#lora_model_container").style.display = "none" - document.querySelector("#tiling_container").style.display = "none" - document.querySelector("#controlnet_model_container").style.display = "none" - document.querySelector("#hypernetwork_model_container").style.display = "" - document.querySelector("#hypernetwork_strength_container").style.display = "" - document.querySelector("#negative-embeddings-button").style.display = "none" - - document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => { - option.style.display = "none" - }) + if (config.backend === "ed_classic") { IMAGE_STEP_SIZE = 64 - customWidthField.step = IMAGE_STEP_SIZE - customHeightField.step = IMAGE_STEP_SIZE } else { - document.querySelector("#lora_model_container").style.display = "" - document.querySelector("#tiling_container").style.display = "" - document.querySelector("#controlnet_model_container").style.display = "" - document.querySelector("#hypernetwork_model_container").style.display = "none" - document.querySelector("#hypernetwork_strength_container").style.display = "none" - - document.querySelectorAll("#sampler_name option.k_diffusion-only").forEach((option) => { - option.style.display = "none" - }) - document.querySelector("#clip_skip_config").classList.remove("displayNone") - document.querySelector("#embeddings-button").classList.remove("displayNone") IMAGE_STEP_SIZE = 8 - customWidthField.step = IMAGE_STEP_SIZE - customHeightField.step = IMAGE_STEP_SIZE } + customWidthField.step = IMAGE_STEP_SIZE + customHeightField.step = IMAGE_STEP_SIZE + + const currentBackendKey = "backend_" + config.backend + + document.querySelectorAll('.gated-feature').forEach((element) => { + const featureKeys = element.getAttribute('data-feature-keys').split(' ') + + if (featureKeys.includes(currentBackendKey)) { + element.style.display = getDefaultDisplay(element) + } else { + element.style.display = 'none' + } + }); + if (config.force_save_metadata) { metadataOutputFormatField.value = config.force_save_metadata } @@ -749,6 +764,11 @@ async function getSystemInfo() { metadataOutputFormatField.disabled = !saveToDiskField.checked } setDiskPath(res["default_output_dir"], force) + + // backend info + if (res["backend_url"]) { + document.querySelector("#backend-url").setAttribute("href", res["backend_url"]) + } } catch (e) { console.log("error fetching devices", e) }