diff --git a/CHANGES.md b/CHANGES.md
index b6a56186..cf26479a 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,21 @@
 # What's new?
 
+## v3.5 (preview)
+### Major Changes
+- **Flux** - full support for the Flux model, including quantized bnb and nf4 models.
+- **LyCORIS** - including `LoCon`, `Hada`, `IA3` and `Lokr`.
+- **11 new samplers** - `DDIM CFG++`, `DPM Fast`, `DPM++ 2m SDE Heun`, `DPM++ 3M SDE`, `Restart`, `Heun PP2`, `IPNDM`, `IPNDM_V`, `LCM`, `[Forge] Flux Realistic`, `[Forge] Flux Realistic (Slow)`.
+- **15 new schedulers** - `Uniform`, `Karras`, `Exponential`, `Polyexponential`, `SGM Uniform`, `KL Optimal`, `Align Your Steps`, `Simple`, `Normal`, `DDIM`, `Beta`, `Turbo`, `Align Your Steps GITS`, `Align Your Steps 11`, `Align Your Steps 32`.
+- **42 new ControlNet filters, and support for lots of new ControlNet models** (including QR ControlNets).
+- **5 upscalers** - `SwinIR`, `ScuNET`, `Nearest`, `Lanczos`, `ESRGAN`.
+- **Faster than v3.0**
+- **Major rewrite of the code** - We've switched to `Forge WebUI` under the hood, which brings a lot of new features, faster image generation, and support for all the extensions in the Forge/Automatic1111 community. This allows Easy Diffusion to stay up-to-date with the latest features, and focus on making the UI and installation experience even easier.
+
+v3.5 is currently an optional upgrade, and you can switch between the v3.0 (diffusers) engine and the v3.5 (webui) engine using the `Settings` tab in the UI.
+
+### Detailed changelog
+* 3.5.0 - 11 Oct 2024 - **Preview release** of the new v3.5 engine, powered by Forge WebUI (a fork of Automatic1111). This enables Flux, SD3, LyCORIS and lots of new features, while using the same familiar Easy Diffusion interface.
+
 ## v3.0
 ### Major Changes
 - **ControlNet** - Full support for ControlNet, with native integration of the common ControlNet models. Just select a control image, then choose the ControlNet filter/model and run. No additional configuration or download necessary. Supports custom ControlNets as well.
@@ -17,6 +33,7 @@
 - **Major rewrite of the code** - We've switched to using diffusers under-the-hood, which allows us to release new features faster, and focus on making the UI and installer even easier to use.
 
 ### Detailed changelog
+* 3.0.10 - 11 Oct 2024 - **Major Update** - An option to upgrade to v3.5, which enables Flux, Stable Diffusion 3, LyCORIS models and lots more.
 * 3.0.9 - 28 May 2024 - Slider for controlling the strength of controlnets.
 * 3.0.8 - 27 May 2024 - SDXL ControlNets for Img2Img and Inpainting.
 * 3.0.7 - 11 Dec 2023 - Setting to enable/disable VAE tiling (in the Image Settings panel). Sometimes VAE tiling reduces the quality of the image, so this setting will help control that.
diff --git a/NSIS/sdui.nsi b/NSIS/sdui.nsi
index 92bec886..2995328c 100644
--- a/NSIS/sdui.nsi
+++ b/NSIS/sdui.nsi
@@ -235,8 +235,8 @@ Section "MainSection" SEC01
   CreateDirectory "$SMPROGRAMS\Easy Diffusion"
   CreateShortCut "$SMPROGRAMS\Easy Diffusion\Easy Diffusion.lnk" "$INSTDIR\Start Stable Diffusion UI.cmd" "" "$INSTDIR\installer_files\cyborg_flower_girl.ico"
 
-  DetailPrint 'Downloading the Stable Diffusion 1.5 model...'
-  NScurl::http get "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/sd-v1-5.safetensors" "$INSTDIR\models\stable-diffusion\sd-v1-5.safetensors" /CANCEL /INSIST /END
+  DetailPrint 'Downloading the Stable Diffusion 1.4 model...'
+  NScurl::http get "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/sd-v1-4.safetensors" "$INSTDIR\models\stable-diffusion\sd-v1-4.safetensors" /CANCEL /INSIST /END
 
   DetailPrint 'Downloading the GFPGAN model...'
   NScurl::http get "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth" "$INSTDIR\models\gfpgan\GFPGANv1.4.pth" /CANCEL /INSIST /END
diff --git a/scripts/Start Stable Diffusion UI.cmd b/scripts/Start Stable Diffusion UI.cmd
index 2bd54845..b4e486b8 100644
--- a/scripts/Start Stable Diffusion UI.cmd
+++ b/scripts/Start Stable Diffusion UI.cmd
@@ -3,7 +3,7 @@
 cd /d %~dp0
 echo Install dir: %~dp0
 
-set PATH=C:\Windows\System32;%PATH%
+set PATH=C:\Windows\System32;C:\Windows\System32\wbem;%PATH%
 set PYTHONHOME=
 
 if exist "on_sd_start.bat" (
diff --git a/scripts/check_modules.py b/scripts/check_modules.py
index d695ee03..6b7c5d31 100644
--- a/scripts/check_modules.py
+++ b/scripts/check_modules.py
@@ -34,6 +34,7 @@ modules_to_check = {
     "sqlalchemy": "2.0.19",
     "python-multipart": "0.0.6",
     # "xformers": "0.0.16",
+    "onnxruntime": "1.19.2",
     "huggingface-hub": "0.21.4",
 }
 modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
@@ -298,7 +299,7 @@ Thanks!"""
 
 def get_config():
     config_directory = os.path.dirname(__file__)  # this will be "scripts"
-    config_yaml = os.path.join(config_directory, "..", "config.yaml")
+    config_yaml = os.path.abspath(os.path.join(config_directory, "..", "config.yaml"))
     config_json = os.path.join(config_directory, "config.json")
 
     config = None
diff --git a/scripts/get_config.py b/scripts/get_config.py
index 0bcc90a1..58f0e9ff 100644
--- a/scripts/get_config.py
+++ b/scripts/get_config.py
@@ -6,6 +6,7 @@ import shutil
 # The config file is in the same directory as this script
 config_directory = os.path.dirname(__file__)
 config_yaml = os.path.join(config_directory, "..", "config.yaml")
+config_yaml = os.path.abspath(config_yaml)
 config_json = os.path.join(config_directory, "config.json")
 
 parser = argparse.ArgumentParser(description='Get values from config file')
diff --git a/scripts/on_env_start.bat b/scripts/on_env_start.bat
index 0eb8a1ba..1d3d8aa2 100644
--- a/scripts/on_env_start.bat
+++ b/scripts/on_env_start.bat
@@ -71,6 +71,7 @@ if "%update_branch%"=="" (
 @copy sd-ui-files\scripts\check_modules.py scripts\ /Y
 @copy sd-ui-files\scripts\get_config.py scripts\ /Y
 @copy sd-ui-files\scripts\config.yaml.sample scripts\ /Y
+@copy sd-ui-files\scripts\webui_console.py scripts\ /Y
 
 @copy "sd-ui-files\scripts\Start Stable Diffusion UI.cmd" . /Y
 @copy "sd-ui-files\scripts\Developer Console.cmd" . /Y
diff --git a/scripts/on_env_start.sh b/scripts/on_env_start.sh
index 77f5a6ef..df681e37 100755
--- a/scripts/on_env_start.sh
+++ b/scripts/on_env_start.sh
@@ -54,6 +54,7 @@ cp sd-ui-files/scripts/bootstrap.sh scripts/
 cp sd-ui-files/scripts/check_modules.py scripts/
 cp sd-ui-files/scripts/get_config.py scripts/
 cp sd-ui-files/scripts/config.yaml.sample scripts/
+cp sd-ui-files/scripts/webui_console.py scripts/
 cp sd-ui-files/scripts/start.sh .
 cp sd-ui-files/scripts/developer_console.sh .
 cp sd-ui-files/scripts/functions.sh scripts/
diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat
index 51668fbd..2d5ff0ba 100644
--- a/scripts/on_sd_start.bat
+++ b/scripts/on_sd_start.bat
@@ -7,6 +7,7 @@
 @copy sd-ui-files\scripts\check_modules.py scripts\ /Y
 @copy sd-ui-files\scripts\get_config.py scripts\ /Y
 @copy sd-ui-files\scripts\config.yaml.sample scripts\ /Y
+@copy sd-ui-files\scripts\webui_console.py scripts\ /Y
 
 if exist "%cd%\profile" (
     set HF_HOME=%cd%\profile\.cache\huggingface
diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh
index fbd39f8c..a9fc809d 100755
--- a/scripts/on_sd_start.sh
+++ b/scripts/on_sd_start.sh
@@ -6,16 +6,20 @@
 cp sd-ui-files/scripts/bootstrap.sh scripts/
 cp sd-ui-files/scripts/check_modules.py scripts/
 cp sd-ui-files/scripts/get_config.py scripts/
 cp sd-ui-files/scripts/config.yaml.sample scripts/
+cp sd-ui-files/scripts/webui_console.py scripts/
 
 source ./scripts/functions.sh
 
 # activate the installer env
-CONDA_BASEPATH=$(conda info --base)
+export CONDA_BASEPATH=$(conda info --base)
 source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # avoids the 'shell not initialized' error
 
 conda activate || fail "Failed to activate conda"
 
+# hack to fix conda 4.14 on older installations
+cp $CONDA_BASEPATH/condabin/conda $CONDA_BASEPATH/bin/conda
+
 # remove the old version of the dev console script, if it's still present
 if [ -e "open_dev_console.sh" ]; then
     rm "open_dev_console.sh"
diff --git a/scripts/webui_console.py b/scripts/webui_console.py
new file mode 100644
index 00000000..6274e1df
--- /dev/null
+++ b/scripts/webui_console.py
@@ -0,0 +1,101 @@
+import os
+import platform
+import subprocess
+
+
+def configure_env(dir):
+    env_entries = {
+        "PATH": [
+            f"{dir}",
+            f"{dir}/bin",
+            f"{dir}/Library/bin",
+            f"{dir}/Scripts",
+            f"{dir}/usr/bin",
+        ],
+        "PYTHONPATH": [
+            f"{dir}",
+            f"{dir}/lib/site-packages",
+            f"{dir}/lib/python3.10/site-packages",
+        ],
+        "PYTHONHOME": [],
+        "PY_LIBS": [
+            f"{dir}/Scripts/Lib",
+            f"{dir}/Scripts/Lib/site-packages",
+            f"{dir}/lib",
+            f"{dir}/lib/python3.10/site-packages",
+        ],
+        "PY_PIP": [f"{dir}/Scripts", f"{dir}/bin"],
+    }
+
+    if platform.system() == "Windows":
+        env_entries["PATH"].append("C:/Windows/System32")
+        env_entries["PATH"].append("C:/Windows/System32/wbem")
+        env_entries["PYTHONNOUSERSITE"] = ["1"]
+        env_entries["PYTHON"] = [f"{dir}/python"]
+        env_entries["GIT"] = [f"{dir}/Library/bin/git"]
+    else:
+        env_entries["PATH"].append("/bin")
+        env_entries["PATH"].append("/usr/bin")
+        env_entries["PATH"].append("/usr/sbin")
+        env_entries["PYTHONNOUSERSITE"] = ["y"]
+        env_entries["PYTHON"] = [f"{dir}/bin/python"]
+        env_entries["GIT"] = [f"{dir}/bin/git"]
+
+    env = {}
+    for key, paths in env_entries.items():
+        paths = [p.replace("/", os.path.sep) for p in paths]
+        paths = os.pathsep.join(paths)
+
+        os.environ[key] = paths
+
+    return env
+
+
+def print_env_info():
+    which_cmd = "where" if platform.system() == "Windows" else "which"
+
+    python = "python"
+
+    def locate_python():
+        nonlocal python
+
+        python = subprocess.getoutput(f"{which_cmd} python")
+        python = python.split("\n")
+        python = python[0].strip()
+        print("python: ", python)
+
+    locate_python()
+
+    def run(cmd):
+        with subprocess.Popen(cmd) as p:
+            p.wait()
+
+    run([which_cmd, "git"])
+    run(["git", "--version"])
+    run([which_cmd, "python"])
+    run([python, "--version"])
+
+    print(f"PATH={os.environ['PATH']}")
+
+    if platform.system() == "Windows":
+        print(f"COMSPEC={os.environ['COMSPEC']}")
+        print("")
+        run("wmic path win32_VideoController get name,AdapterRAM,DriverDate,DriverVersion".split(" "))
+
+    print(f"PYTHONPATH={os.environ['PYTHONPATH']}")
+    print("")
+
+
+def open_dev_shell():
+    if platform.system() == "Windows":
+        subprocess.Popen("cmd").communicate()
+    else:
+        subprocess.Popen("bash").communicate()
+
+
+if __name__ == "__main__":
+    env_dir = os.path.abspath(os.path.join("webui", "system"))
+
+    configure_env(env_dir)
+    print_env_info()
+    open_dev_shell()
diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py
index 43d0e3c4..de618bde 100644
--- a/ui/easydiffusion/app.py
+++ b/ui/easydiffusion/app.py
@@ -11,7 +11,7 @@ from ruamel.yaml import YAML
 import urllib
 import warnings
 
-from easydiffusion import task_manager
+from easydiffusion import task_manager, backend_manager
 from easydiffusion.utils import log
 from rich.logging import RichHandler
 from rich.console import Console
@@ -36,10 +36,10 @@ ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, ".."))
 
 SD_UI_DIR = os.getenv("SD_UI_PATH", None)
 
-CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
-BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket"))
+CONFIG_DIR = os.path.abspath(os.path.join(ROOT_DIR, "scripts"))
+BUCKET_DIR = os.path.abspath(os.path.join(ROOT_DIR, "bucket"))
 
-USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins"))
+USER_PLUGINS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "plugins"))
 CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins"))
 
 USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui")
@@ -60,7 +60,7 @@ APP_CONFIG_DEFAULTS = {
     "ui": {
         "open_browser_on_start": True,
     },
-    "use_v3_engine": True,
+    "backend": "ed_diffusers",
 }
 
 IMAGE_EXTENSIONS = [
@@ -77,7 +77,7 @@ IMAGE_EXTENSIONS = [
     ".avif",
     ".svg",
 ]
-CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
+CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "modifiers"))
 CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [
     ".portrait",
     "_portrait",
@@ -91,7 +91,7 @@ CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [
     "-landscape",
 ]
 
-MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
+MODELS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "models"))
 
 
 def init():
@@ -105,9 +105,11 @@ def init():
     config = getConfig()
 
     config_models_dir = config.get("models_dir", None)
-    if (config_models_dir is not None and config_models_dir != ""):
+    if config_models_dir is not None and config_models_dir != "":
         MODELS_DIR = config_models_dir
 
+    backend_manager.start_backend()
+
 
 def init_render_threads():
     load_server_plugins()
@@ -117,6 +119,7 @@ def init_render_threads():
 
 def getConfig(default_val=APP_CONFIG_DEFAULTS):
     config_yaml_path = os.path.join(CONFIG_DIR, "..", "config.yaml")
+    config_yaml_path = os.path.abspath(config_yaml_path)
 
     # migrate the old config yaml location
     config_legacy_yaml = os.path.join(CONFIG_DIR, "config.yaml")
@@ -124,9 +127,9 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
         shutil.move(config_legacy_yaml, config_yaml_path)
 
     def set_config_on_startup(config: dict):
-        if getConfig.__use_v3_engine_on_startup is None:
-            getConfig.__use_v3_engine_on_startup = config.get("use_v3_engine", True)
-        config["config_on_startup"] = {"use_v3_engine": getConfig.__use_v3_engine_on_startup}
+        if getConfig.__use_backend_on_startup is None:
+            getConfig.__use_backend_on_startup = config.get("backend", "ed_diffusers")
+        config["config_on_startup"] = {"backend": getConfig.__use_backend_on_startup}
 
     if os.path.isfile(config_yaml_path):
         try:
@@ -144,6 +147,15 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
             else:
                 config["net"]["listen_to_network"] = True
 
+            if "backend" not in config:
+                if "use_v3_engine" in config:
+                    config["backend"] = "ed_diffusers" if config["use_v3_engine"] else "ed_classic"
+                else:
+                    config["backend"] = "ed_diffusers"
+                    # this default will need to be smarter when WebUI becomes the main backend, but needs to maintain backwards
+                    # compatibility with existing ED 3.0 installations that haven't opted into the WebUI backend, and haven't
+                    # set a "use_v3_engine" flag in their config
+
             set_config_on_startup(config)
 
             return config
@@ -174,7 +186,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
     return default_val
 
 
-getConfig.__use_v3_engine_on_startup = None
+getConfig.__use_backend_on_startup = None
 
 
 def setConfig(config):
@@ -307,28 +319,43 @@ def getIPConfig():
 
 
 def open_browser():
+    from easydiffusion.backend_manager import backend
+
     config = getConfig()
     ui = config.get("ui", {})
     net = config.get("net", {})
     port = net.get("listen_port", 9000)
 
-    if ui.get("open_browser_on_start", True):
-        import webbrowser
+    if backend.is_installed():
+        if ui.get("open_browser_on_start", True):
+            import webbrowser
 
-        log.info("Opening browser..")
+            log.info("Opening browser..")
 
-        webbrowser.open(f"http://localhost:{port}")
+            webbrowser.open(f"http://localhost:{port}")
 
-    Console().print(
-        Panel(
-            "\n"
-            + "[white]Easy Diffusion is ready to serve requests.\n\n"
-            + "A new browser tab should have been opened by now.\n"
-            + f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n",
-            title="Easy Diffusion is ready",
-            style="bold yellow on blue",
+        Console().print(
+            Panel(
+                "\n"
+                + "[white]Easy Diffusion is ready to serve requests.\n\n"
+                + "A new browser tab should have been opened by now.\n"
+                + f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n",
+                title="Easy Diffusion is ready",
+                style="bold yellow on blue",
+            )
+        )
+    else:
+        backend_name = config["backend"]
+        Console().print(
+            Panel(
+                "\n"
+                + f"[white]Backend: {backend_name} is still installing..\n\n"
+                + "A new browser tab will open automatically after it finishes.\n"
+                + f"If it does not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n",
+                title=f"Backend engine is installing",
+                style="bold yellow on blue",
+            )
         )
-    )
 
 
 def fail_and_die(fail_type: str, data: str):
diff --git a/ui/easydiffusion/backend_manager.py b/ui/easydiffusion/backend_manager.py
new file mode 100644
index 00000000..1c280db4
--- /dev/null
+++ b/ui/easydiffusion/backend_manager.py
@@ -0,0 +1,105 @@
+import os
+import ast
+import sys
+import importlib.util
+import traceback
+
+from easydiffusion.utils import log
+
+backend = None
+curr_backend_name = None
+
+
+def is_valid_backend(file_path):
+    with open(file_path, "r", encoding="utf-8") as file:
+        node = ast.parse(file.read())
+
+    # Check for presence of a dictionary named 'ed_info'
+    for item in node.body:
+        if isinstance(item, ast.Assign):
+            for target in item.targets:
+                if isinstance(target, ast.Name) and target.id == "ed_info":
+                    return True
+
+    return False
+
+
+def find_valid_backends(root_dir) -> dict:
+    backends_path = os.path.join(root_dir, "backends")
+    valid_backends = {}
+
+    if not os.path.exists(backends_path):
+        return valid_backends
+
+    for item in os.listdir(backends_path):
+        item_path = os.path.join(backends_path, item)
+
+        if os.path.isdir(item_path):
+            init_file = os.path.join(item_path, "__init__.py")
+            if os.path.exists(init_file) and is_valid_backend(init_file):
+                valid_backends[item] = item_path
+        elif item.endswith(".py"):
+            if is_valid_backend(item_path):
+                backend_name = os.path.splitext(item)[0]  # strip the .py extension
+                valid_backends[backend_name] = item_path
+
+    return valid_backends
+
+
+def load_backend_module(backend_name, backend_dict):
+    if backend_name not in backend_dict:
+        raise ValueError(f"Backend '{backend_name}' not found.")
+
+    module_path = backend_dict[backend_name]
+
+    mod_dir = os.path.dirname(module_path)
+
+    sys.path.insert(0, mod_dir)
+
+    # If it's a package (directory), add its parent directory to sys.path
+    if os.path.isdir(module_path):
+        module_path = os.path.join(module_path, "__init__.py")
+
+    spec = importlib.util.spec_from_file_location(backend_name, module_path)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+
+    if mod_dir in sys.path:
+        sys.path.remove(mod_dir)
+
+    log.info(f"Loaded backend: {module}")
+
+    return module
+
+
+def start_backend():
+    global backend, curr_backend_name
+
+    from easydiffusion.app import getConfig, ROOT_DIR
+
+    curr_dir = os.path.dirname(__file__)
+
+    backends = find_valid_backends(curr_dir)
+    plugin_backends = find_valid_backends(ROOT_DIR)
+    backends.update(plugin_backends)
+
+    config = getConfig()
+    backend_name = config["backend"]
+
+    if backend_name not in backends:
+        raise RuntimeError(
+            f"Couldn't find the backend configured in config.yaml: {backend_name}. Please check the name!"
+        )
+
+    if backend is not None and backend_name != curr_backend_name:
+        try:
+            backend.stop_backend()
+        except:
+            log.exception(traceback.format_exc())
+
+    log.info(f"Loading backend: {backend_name}")
+    backend = load_backend_module(backend_name, backends)
+
+    try:
+        backend.start_backend()
+    except:
+        log.exception(traceback.format_exc())
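A note for plugin authors: `find_valid_backends()` above only checks (via `ast`) that a module assigns a top-level `ed_info` dict, so a custom backend dropped into the `backends/` folder of the install directory can be selected with `backend: <module name>` in `config.yaml`. The sketch below is illustrative only — it is not a file in this PR — and its function names mirror what the bundled `ed_classic`/`ed_diffusers`/`webui` backends export:

```python
# Hypothetical example: <install dir>/backends/my_backend.py (not part of this PR).
# backend_manager.find_valid_backends() accepts it because `ed_info` is assigned
# at module level; config.yaml would select it with `backend: my_backend`.

ed_info = {
    "name": "Example backend (illustration only)",
    "version": (0, 1, 0),
    "type": "backend",
}


def install_backend():
    pass  # one-time setup, e.g. create an env or clone a repo


def is_installed():
    return True  # app.open_browser() uses this to choose the startup banner


def start_backend():
    print("example backend started")  # called by backend_manager.start_backend()


def stop_backend():
    pass  # called before switching to a different backend
```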
diff --git a/ui/easydiffusion/backends/ed_classic.py b/ui/easydiffusion/backends/ed_classic.py
new file mode 100644
index 00000000..07077b6c
--- /dev/null
+++ b/ui/easydiffusion/backends/ed_classic.py
@@ -0,0 +1,28 @@
+from sdkit_common import (
+    start_backend,
+    stop_backend,
+    install_backend,
+    uninstall_backend,
+    is_installed,
+    create_sdkit_context,
+    ping,
+    load_model,
+    unload_model,
+    set_options,
+    generate_images,
+    filter_images,
+    get_url,
+    stop_rendering,
+    refresh_models,
+    list_controlnet_filters,
+)
+
+ed_info = {
+    "name": "Classic backend for Easy Diffusion v2",
+    "version": (1, 0, 0),
+    "type": "backend",
+}
+
+
+def create_context():
+    return create_sdkit_context(use_diffusers=False)
diff --git a/ui/easydiffusion/backends/ed_diffusers.py b/ui/easydiffusion/backends/ed_diffusers.py
new file mode 100644
index 00000000..b5cf8744
--- /dev/null
+++ b/ui/easydiffusion/backends/ed_diffusers.py
@@ -0,0 +1,28 @@
+from sdkit_common import (
+    start_backend,
+    stop_backend,
+    install_backend,
+    uninstall_backend,
+    is_installed,
+    create_sdkit_context,
+    ping,
+    load_model,
+    unload_model,
+    set_options,
+    generate_images,
+    filter_images,
+    get_url,
+    stop_rendering,
+    refresh_models,
+    list_controlnet_filters,
+)
+
+ed_info = {
+    "name": "Diffusers Backend for Easy Diffusion v3",
+    "version": (1, 0, 0),
+    "type": "backend",
+}
+
+
+def create_context():
+    return create_sdkit_context(use_diffusers=True)
diff --git a/ui/easydiffusion/backends/sdkit_common.py b/ui/easydiffusion/backends/sdkit_common.py
new file mode 100644
index 00000000..8d6c377b
--- /dev/null
+++ b/ui/easydiffusion/backends/sdkit_common.py
@@ -0,0 +1,246 @@
+from sdkit import Context
+
+from easydiffusion.types import UserInitiatedStop
+
+from sdkit.utils import (
+    diffusers_latent_samples_to_images,
+    gc,
+    img_to_base64_str,
+    latent_samples_to_images,
+)
+
+opts = {}
+
+
+def install_backend():
+    pass
+
+
+def start_backend():
+    print("Started sdkit backend")
+
+
+def stop_backend():
+    pass
+
+
+def uninstall_backend():
+    pass
+
+
+def is_installed():
+    return True
+
+
+def create_sdkit_context(use_diffusers):
+    c = Context()
+    c.test_diffusers = use_diffusers
+    return c
+
+
+def ping(timeout=1):
+    return True
+
+
+def load_model(context, model_type, **kwargs):
+    from sdkit.models import load_model
+
+    load_model(context, model_type, **kwargs)
+
+
+def unload_model(context, model_type, **kwargs):
+    from sdkit.models import unload_model
+
+    unload_model(context, model_type, **kwargs)
+
+
+def set_options(context, **kwargs):
+    if "vae_tiling" in kwargs and context.test_diffusers:
+        pipe = context.models["stable-diffusion"]["default"]
+        vae_tiling = kwargs["vae_tiling"]
+
+        if vae_tiling:
+            if hasattr(pipe, "enable_vae_tiling"):
+                pipe.enable_vae_tiling()
+        else:
+            if hasattr(pipe, "disable_vae_tiling"):
+                pipe.disable_vae_tiling()
+
+    for key in (
+        "output_format",
+        "output_quality",
+        "output_lossless",
+        "stream_image_progress",
+        "stream_image_progress_interval",
+    ):
+        if key in kwargs:
+            opts[key] = kwargs[key]
+
+
+def generate_images(
+    context: Context,
+    callback=None,
+    controlnet_filter=None,
+    distilled_guidance_scale: float = 3.5,
+    scheduler_name: str = "simple",
+    output_type="pil",
+    **req,
+):
+    from sdkit.generate import generate_images
+
+    if req["init_image"] is not None and not context.test_diffusers:
+        req["sampler_name"] = "ddim"
+
+    gc(context)
+
+    context.stop_processing = False
+
+    if req["control_image"] and controlnet_filter:
+        controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter)
+        req["control_image"] = filter_images(context, req["control_image"], controlnet_filter)[0]
+
+    callback = make_step_callback(context, callback)
+
+    try:
+        images = generate_images(context, callback=callback, **req)
+    except UserInitiatedStop:
+        images = []
+        if context.partial_x_samples is not None:
+            if context.test_diffusers:
+                images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
+            else:
+                images = latent_samples_to_images(context, context.partial_x_samples)
+    finally:
+        if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
+            if not context.test_diffusers:
+                del context.partial_x_samples
+            context.partial_x_samples = None
+
+    gc(context)
+
+    if output_type == "base64":
+        output_format = opts.get("output_format", "jpeg")
+        output_quality = opts.get("output_quality", 75)
+        output_lossless = opts.get("output_lossless", False)
+        images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images]
+
+    return images
+
+
+def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"):
+    gc(context)
+
+    if "nsfw_checker" in filters:
+        filters.remove("nsfw_checker")  # handled by ED directly
+
+    if len(filters) == 0:
+        return images
+
+    images = _filter_images(context, images, filters, filter_params)
+
+    if input_type == "base64":
+        output_format = opts.get("output_format", "jpg")
+        output_quality = opts.get("output_quality", 75)
+        output_lossless = opts.get("output_lossless", False)
+        images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images]
+
+    return images
+
+
+def _filter_images(context, images, filters, filter_params={}):
+    from sdkit.filter import apply_filters
+
+    filters = filters if isinstance(filters, list) else [filters]
+    filters = convert_ED_controlnet_filter_name(filters)
+
+    for filter_name in filters:
+        params = filter_params.get(filter_name, {})
+
+        previous_state = before_filter(context, filter_name, params)
+
+        try:
+            images = apply_filters(context, filter_name, images, **params)
+        finally:
+            after_filter(context, filter_name, params, previous_state)
+
+    return images
+
+
+def before_filter(context, filter_name, filter_params):
+    if filter_name == "codeformer":
+        from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
+
+        default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
+        prev_realesrgan_path = None
+
+        upscale_faces = filter_params.get("upscale_faces", False)
+        if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
+            prev_realesrgan_path = context.model_paths.get("realesrgan")
+            context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
+            load_model(context, "realesrgan")
+
+        return prev_realesrgan_path
+
+
+def after_filter(context, filter_name, filter_params, previous_state):
+    if filter_name == "codeformer":
+        prev_realesrgan_path = previous_state
+        if prev_realesrgan_path:
+            context.model_paths["realesrgan"] = prev_realesrgan_path
+            load_model(context, "realesrgan")
+
+
+def get_url():
+    pass
+
+
+def stop_rendering(context):
+    context.stop_processing = True
+
+
+def refresh_models():
+    pass
+
+
+def list_controlnet_filters():
+    from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
+
+    return cn_filters
+
+
+def make_step_callback(context, callback):
+    def on_step(x_samples, i, *args):
+        stream_image_progress = opts.get("stream_image_progress", False)
+        stream_image_progress_interval = opts.get("stream_image_progress_interval", 3)
+
+        if context.test_diffusers:
+            context.partial_x_samples = (x_samples, args[0])
+        else:
+            context.partial_x_samples = x_samples
+
+        if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
+            if context.test_diffusers:
+                images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
+            else:
+                images = latent_samples_to_images(context, context.partial_x_samples)
+        else:
+            images = None
+
+        if callback:
+            callback(images, i, *args)
+
+        if context.stop_processing:
+            raise UserInitiatedStop("User requested that we stop processing")
+
+    return on_step
+
+
+def convert_ED_controlnet_filter_name(filter):
+    def cn(n):
+        if n.startswith("controlnet_"):
+            return n[len("controlnet_") :]
+        return n
+
+    if isinstance(filter, list):
+        return [cn(f) for f in filter]
+    return cn(filter)
diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py
new file mode 100644
index 00000000..947b2e24
--- /dev/null
+++ b/ui/easydiffusion/backends/webui/__init__.py
@@ -0,0 +1,450 @@
+import os
+import platform
+import subprocess
+import threading
+from threading import local
+import psutil
+import time
+import shutil
+
+from easydiffusion.app import ROOT_DIR, getConfig
+from easydiffusion.model_manager import get_model_dirs
+from easydiffusion.utils import log
+
+from . import impl
+from .impl import (
+    ping,
+    load_model,
+    unload_model,
+    set_options,
+    generate_images,
+    filter_images,
+    get_url,
+    stop_rendering,
+    refresh_models,
+    list_controlnet_filters,
+)
+
+
+ed_info = {
+    "name": "WebUI backend for Easy Diffusion",
+    "version": (1, 0, 0),
+    "type": "backend",
+}
+
+WEBUI_REPO = "https://github.com/lllyasviel/stable-diffusion-webui-forge.git"
+WEBUI_COMMIT = "b592142f3b46852263747a4efd0d244ad17b5bb3"
+
+BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui"))
+SYSTEM_DIR = os.path.join(BACKEND_DIR, "system")
+WEBUI_DIR = os.path.join(BACKEND_DIR, "webui")
+
+OS_NAME = platform.system()
+
+MODELS_TO_OVERRIDE = {
+    "stable-diffusion": "--ckpt-dir",
+    "vae": "--vae-dir",
+    "hypernetwork": "--hypernetwork-dir",
+    "gfpgan": "--gfpgan-models-path",
+    "realesrgan": "--realesrgan-models-path",
+    "lora": "--lora-dir",
+    "codeformer": "--codeformer-models-path",
+    "embeddings": "--embeddings-dir",
+    "controlnet": "--controlnet-dir",
+}
+
+backend_process = None
+conda = "conda"
+
+
+def locate_conda():
+    global conda
+
+    which = "where" if OS_NAME == "Windows" else "which"
+    conda = subprocess.getoutput(f"{which} conda")
+    conda = conda.split("\n")
+    conda = conda[0].strip()
+    print("conda: ", conda)
+
+
+locate_conda()
+
+
+def install_backend():
+    print("Installing the WebUI backend..")
+
+    # create the conda env
+    run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR)
+
+    print("Installing packages..")
+
+    # install python 3.10 and git in the conda env
+    run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR)
+
+    # print info
+    run_in_conda(["git", "--version"], cwd=ROOT_DIR)
+    run_in_conda(["python", "--version"], cwd=ROOT_DIR)
+
+    # clone webui
+    run_in_conda(["git", "clone", WEBUI_REPO, WEBUI_DIR], cwd=ROOT_DIR)
+
+    # install cpu-only torch if the PC doesn't have a graphics card (for Windows and Linux).
+    # this avoids WebUI installing a CUDA version and trying to activate it
+    if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card():
+        run_in_conda(["python", "-m", "pip", "install", "torch", "torchvision"], cwd=WEBUI_DIR)
+
+
+def start_backend():
+    config = getConfig()
+    backend_config = config.get("backend_config", {})
+
+    if not os.path.exists(BACKEND_DIR):
+        install_backend()
+
+    was_still_installing = not is_installed()
+
+    if backend_config.get("auto_update", True):
+        run_in_conda(["git", "add", "-A", "."], cwd=WEBUI_DIR)
+        run_in_conda(["git", "stash"], cwd=WEBUI_DIR)
+        run_in_conda(["git", "reset", "--hard"], cwd=WEBUI_DIR)
+        run_in_conda(["git", "fetch"], cwd=WEBUI_DIR)
+        run_in_conda(["git", "-c", "advice.detachedHead=false", "checkout", WEBUI_COMMIT], cwd=WEBUI_DIR)
+
+    # hack to prevent webui-macos-env.sh from overwriting the COMMANDLINE_ARGS env variable
+    mac_webui_file = os.path.join(WEBUI_DIR, "webui-macos-env.sh")
+    if os.path.exists(mac_webui_file):
+        os.remove(mac_webui_file)
+
+    impl.WEBUI_HOST = backend_config.get("host", "localhost")
+    impl.WEBUI_PORT = backend_config.get("port", "7860")
+
+    env = dict(os.environ)
+    env.update(get_env())
+
+    def restart_if_webui_dies_after_starting():
+        has_started = False
+
+        while True:
+            try:
+                impl.ping(timeout=30)
+
+                is_first_start = not has_started
+                has_started = True
+
+                if was_still_installing and is_first_start:
+                    ui = config.get("ui", {})
+                    net = config.get("net", {})
+                    port = net.get("listen_port", 9000)
+
+                    if ui.get("open_browser_on_start", True):
+                        import webbrowser
+
+                        log.info("Opening browser..")
+
+                        webbrowser.open(f"http://localhost:{port}")
+            except (TimeoutError, ConnectionError):
+                if has_started:  # process probably died
+                    print("######################## WebUI probably died. Restarting...")
+                    stop_backend()
+                    backend_thread = threading.Thread(target=target)
+                    backend_thread.start()
+                    break
+            except Exception:
+                import traceback
+
+                log.exception(traceback.format_exc())
+
+            time.sleep(1)
+
+    def target():
+        global backend_process
+
+        cmd = "webui.bat" if OS_NAME == "Windows" else "./webui.sh"
+
+        print("starting", cmd, WEBUI_DIR)
+        backend_process = run_in_conda([cmd], cwd=WEBUI_DIR, env=env, wait=False, output_prefix="[WebUI] ")
+
+        restart_if_dead_thread = threading.Thread(target=restart_if_webui_dies_after_starting)
+        restart_if_dead_thread.start()
+
+        backend_process.wait()
+
+    backend_thread = threading.Thread(target=target)
+    backend_thread.start()
+
+    start_proxy()
+
+
+def start_proxy():
+    # proxy
+    from easydiffusion.server import server_api
+    from fastapi import FastAPI, Request
+    from fastapi.responses import Response
+    import json
+
+    URI_PREFIX = "/webui"
+
+    webui_proxy = FastAPI(root_path=f"{URI_PREFIX}", docs_url="/swagger")
+
+    @webui_proxy.get("{uri:path}")
+    def proxy_get(uri: str, req: Request):
+        if uri == "/openapi-proxy.json":
+            uri = "/openapi.json"
+
+        res = impl.webui_get(uri, headers=req.headers)
+
+        content = res.content
+        headers = dict(res.headers)
+
+        if uri == "/docs":
+            content = res.text.replace("url: '/openapi.json'", f"url: '{URI_PREFIX}/openapi-proxy.json'")
+        elif uri == "/openapi.json":
+            content = res.json()
+            content["paths"] = {f"{URI_PREFIX}{k}": v for k, v in content["paths"].items()}
+            content = json.dumps(content)
+
+        if isinstance(content, str):
+            content = bytes(content, encoding="utf-8")
+            headers["content-length"] = str(len(content))
+
+        # Return the same response back to the client
+        return Response(content=content, status_code=res.status_code, headers=headers)
+
+    @webui_proxy.post("{uri:path}")
+    async def proxy_post(uri: str, req: Request):
+        body = await req.body()
+        res = impl.webui_post(uri, data=body, headers=req.headers)
+
+        # Return the same response back to the client
+        return Response(content=res.content, status_code=res.status_code, headers=dict(res.headers))
+
+    server_api.mount(f"{URI_PREFIX}", webui_proxy)
+
+
+def stop_backend():
+    global backend_process
+
+    if backend_process:
+        try:
+            kill(backend_process.pid)
+        except:
+            pass
+
+    backend_process = None
+
+
+def uninstall_backend():
+    shutil.rmtree(BACKEND_DIR)
+
+
+def is_installed():
+    if not os.path.exists(BACKEND_DIR) or not os.path.exists(SYSTEM_DIR) or not os.path.exists(WEBUI_DIR):
+        return False  # backend files haven't been set up yet
+
+    env = dict(os.environ)
+    env.update(get_env())
+
+    try:
+        out = check_output_in_conda(["python", "-m", "pip", "show", "torch"], env=env)
+        return "Version" in out.decode()
+    except subprocess.CalledProcessError:
+        pass
+
+    return False
+
+
+def read_output(pipe, prefix=""):
+    while True:
+        output = pipe.readline()
+        if output:
+            print(f"{prefix}{output.decode('utf-8')}", end="")
+        else:
+            break  # Pipe is closed, subprocess has likely exited
+
+
+def run(cmds: list, cwd=None, env=None, stream_output=True, wait=True, output_prefix=""):
+    p = subprocess.Popen(cmds, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    if stream_output:
+        output_thread = threading.Thread(target=read_output, args=(p.stdout, output_prefix))
+        output_thread.start()
+
+    if wait:
+        p.wait()
+
+    return p
+
+
+def run_in_conda(cmds: list, *args, **kwargs):
+    cmds = [conda, "run", "--no-capture-output", "--prefix", SYSTEM_DIR] + cmds
+    return run(cmds, *args, **kwargs)
+
+
+def check_output_in_conda(cmds: list, cwd=None, env=None):
+    cmds = [conda, "run", "--no-capture-output", "--prefix", SYSTEM_DIR] + cmds
+    return subprocess.check_output(cmds, cwd=cwd, env=env, stderr=subprocess.PIPE)
+
+
+def create_context():
+    context = local()
+
+    # temp hack, throws an attribute not found error otherwise
+    context.device = "cuda:0"
+    context.half_precision = True
+    context.vram_usage_level = None
+
+    context.models = {}
+    context.model_paths = {}
+    context.model_configs = {}
+    context.device_name = None
+    context.vram_optimizations = set()
+    context.vram_usage_level = "balanced"
+    context.test_diffusers = False
+    context.enable_codeformer = False
+
+    return context
+
+
+def get_env():
+    dir = os.path.abspath(SYSTEM_DIR)
+
+    if not os.path.exists(dir):
+        raise RuntimeError("The system folder is missing!")
+
+    config = getConfig()
+    models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
+
+    model_path_args = get_model_path_args()
+
+    env_entries = {
+        "PATH": [
+            f"{dir}",
+            f"{dir}/bin",
+            f"{dir}/Library/bin",
+            f"{dir}/Scripts",
+            f"{dir}/usr/bin",
+        ],
+        "PYTHONPATH": [
+            f"{dir}",
+            f"{dir}/lib/site-packages",
+            f"{dir}/lib/python3.10/site-packages",
+        ],
+        "PYTHONHOME": [],
+        "PY_LIBS": [
+            f"{dir}/Scripts/Lib",
+            f"{dir}/Scripts/Lib/site-packages",
+            f"{dir}/lib",
+            f"{dir}/lib/python3.10/site-packages",
+        ],
+        "PY_PIP": [f"{dir}/Scripts", f"{dir}/bin"],
+        "PIP_INSTALLER_LOCATION": [],  # [f"{dir}/python/get-pip.py"],
+        "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"],
+        "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"],
+        "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args} --skip-torch-cuda-test'],
+        "SKIP_VENV": ["1"],
+        "SD_WEBUI_RESTARTING": ["1"],
+    }
+
+    if OS_NAME == "Windows":
+        env_entries["PATH"].append("C:/Windows/System32")
+        env_entries["PATH"].append("C:/Windows/System32/wbem")
+        env_entries["PYTHONNOUSERSITE"] = ["1"]
+        env_entries["PYTHON"] = [f"{dir}/python"]
+        env_entries["GIT"] = [f"{dir}/Library/bin/git"]
+    else:
+        env_entries["PATH"].append("/bin")
+        env_entries["PATH"].append("/usr/bin")
+        env_entries["PATH"].append("/usr/sbin")
+        env_entries["PYTHONNOUSERSITE"] = ["y"]
+        env_entries["PYTHON"] = [f"{dir}/bin/python"]
+        env_entries["GIT"] = [f"{dir}/bin/git"]
+        env_entries["venv_dir"] = ["-"]
+
+    if OS_NAME == "Darwin":
+        # based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/e26abf87ecd1eefd9ab0a198eee56f9c643e4001/webui-macos-env.sh
+        # hack - have to define these here, otherwise webui-macos-env.sh will overwrite COMMANDLINE_ARGS
+        env_entries["COMMANDLINE_ARGS"][0] += " --upcast-sampling --no-half-vae --use-cpu interrogate"
+        env_entries["PYTORCH_ENABLE_MPS_FALLBACK"] = ["1"]
+
+        cpu_name = str(subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]))
+        if "Intel" in cpu_name:
+            env_entries["TORCH_COMMAND"] = ["pip install torch==2.1.2 torchvision==0.16.2"]
+        else:
+            env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
+    else:
+        import torch
+        from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
+
+        vram_usage_level = config.get("vram_usage_level", "balanced")
+        if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
+            env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
+        else:
+            c = local()
+            c.device_name = torch.cuda.get_device_name()
+
+            if needs_to_force_full_precision(c):
+                env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
+
+            if vram_usage_level == "low":
+                env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram"
+            elif vram_usage_level == "high":
+                env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram"
+
+    env = {}
+    for key, paths in env_entries.items():
+        paths = [p.replace("/", os.path.sep) for p in paths]
+        paths = os.pathsep.join(paths)
+
+        env[key] = paths
+
+    return env
+
+
+def has_discrete_graphics_card():
+    system = OS_NAME
+
+    if system == "Windows":
+        try:
+            output = subprocess.check_output(
+                ["wmic", "path", "win32_videocontroller", "get", "name"], stderr=subprocess.STDOUT
+            )
+            # Filter for discrete graphics cards (NVIDIA, AMD, etc.)
+            discrete_gpus = ["NVIDIA", "AMD", "ATI"]
+            return any(gpu in output.decode() for gpu in discrete_gpus)
+        except subprocess.CalledProcessError:
+            return False
+
+    elif system == "Linux":
+        try:
+            output = subprocess.check_output(["lspci"], stderr=subprocess.STDOUT)
+            # Check for discrete GPUs (NVIDIA, AMD)
+            discrete_gpus = ["NVIDIA", "AMD", "Advanced Micro Devices"]
+            return any(gpu in line for line in output.decode().splitlines() for gpu in discrete_gpus)
+        except subprocess.CalledProcessError:
+            return False
+
+    elif system == "Darwin":  # macOS
+        try:
+            output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], stderr=subprocess.STDOUT)
+            # Check for discrete GPU in the output
+            return "NVIDIA" in output.decode() or "AMD" in output.decode()
+        except subprocess.CalledProcessError:
+            return False
+
+    return False
+
+
+# https://stackoverflow.com/a/25134985
+def kill(proc_pid):
+    process = psutil.Process(proc_pid)
+    for proc in process.children(recursive=True):
+        proc.kill()
+    process.kill()
+
+
+def get_model_path_args():
+    args = []
+    for model_type, flag in MODELS_TO_OVERRIDE.items():
+        model_dir = get_model_dirs(model_type)[0]
+        args.append(f'{flag} "{model_dir}"')
+
+    return " ".join(args)
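Since `start_proxy()` mounts the Forge API under `/webui` on Easy Diffusion's own server, the WebUI endpoints become reachable through ED's port once the backend is up. A quick smoke test, assuming ED's default `listen_port` of 9000 (the paths follow from the proxy code above):

```python
import requests

# Requests under /webui are forwarded to the Forge instance (localhost:7860 by default)
base = "http://localhost:9000/webui"

# same liveness endpoint that impl.ping() polls internally
print(requests.get(f"{base}/internal/ping", timeout=5).status_code)

# Forge's Swagger page, re-exposed with a rewritten openapi URL by proxy_get()
print(requests.get(f"{base}/docs", timeout=5).status_code)
```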
diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py
new file mode 100644
index 00000000..0e644bd5
--- /dev/null
+++ b/ui/easydiffusion/backends/webui/impl.py
@@ -0,0 +1,654 @@
+import os
+import requests
+from requests.exceptions import ConnectTimeout, ConnectionError, ReadTimeout
+from typing import Union, List
+from threading import local as Context
+from threading import Thread
+import uuid
+import time
+from copy import deepcopy
+
+from sdkit.utils import base64_str_to_img, img_to_base64_str, log
+
+WEBUI_HOST = "localhost"
+WEBUI_PORT = "7860"
+
+DEFAULT_WEBUI_OPTIONS = {
+    "show_progress_every_n_steps": 3,
+    "show_progress_grid": True,
+    "live_previews_enable": False,
+    "forge_additional_modules": [],
+}
+
+
+webui_opts: dict = None
+
+
+curr_models = {
+    "stable-diffusion": None,
+    "vae": None,
+}
+
+
+def set_options(context, **kwargs):
+    changed_opts = {}
+
+    opts_mapping = {
+        "stream_image_progress": ("live_previews_enable", bool),
+        "stream_image_progress_interval": ("show_progress_every_n_steps", int),
+        "clip_skip": ("CLIP_stop_at_last_layers", int),
+        "clip_skip_sdxl": ("sdxl_clip_l_skip", bool),
+        "output_format": ("samples_format", str),
+    }
+
+    for ed_key, webui_key in opts_mapping.items():
+        webui_key, webui_type = webui_key
+
+        if ed_key in kwargs and (webui_opts is None or webui_opts.get(webui_key, False) != webui_type(kwargs[ed_key])):
+            changed_opts[webui_key] = webui_type(kwargs[ed_key])
+
+    if changed_opts:
+        changed_opts["sd_model_checkpoint"] = curr_models["stable-diffusion"]
+
+        print(f"Got options: {kwargs}. Sending options: {changed_opts}")
+
+        try:
+            res = webui_post("/sdapi/v1/options", json=changed_opts)
+            if res.status_code != 200:
+                raise Exception(res.text)
+
+            webui_opts.update(changed_opts)
+        except Exception as e:
+            print(f"Error setting options: {e}")
+
+
+def ping(timeout=1):
+    "timeout (in seconds)"
+
+    global webui_opts
+
+    try:
+        res = webui_get("/internal/ping", timeout=timeout)
+
+        if res.status_code != 200:
+            raise ConnectTimeout(res.text)
+
+        if webui_opts is None:
+            try:
+                res = webui_post("/sdapi/v1/options", json=DEFAULT_WEBUI_OPTIONS)
+                if res.status_code != 200:
+                    raise Exception(res.text)
+            except Exception as e:
+                print(f"Error setting options: {e}")
+
+            try:
+                res = webui_get("/sdapi/v1/options")
+                if res.status_code != 200:
+                    raise Exception(res.text)
+
+                webui_opts = res.json()
+            except Exception as e:
+                print(f"Error getting options: {e}")
+
+        return True
+    except (ConnectTimeout, ConnectionError, ReadTimeout) as e:
+        raise TimeoutError(e)
+
+
+def load_model(context, model_type, **kwargs):
+    model_path = context.model_paths[model_type]
+
+    if webui_opts is None:
+        print("Server not ready, can't set the model")
+        return
+
+    if model_type == "stable-diffusion":
+        model_name = os.path.basename(model_path)
+        model_name = os.path.splitext(model_name)[0]
+        print(f"setting sd model: {model_name}")
+        if curr_models[model_type] != model_name:
+            try:
+                res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name})
+                if res.status_code != 200:
+                    raise Exception(res.text)
+            except Exception as e:
+                raise RuntimeError(
+                    f"The engine failed to set the required options. Please check the logs in the command line window for more details."
+                )
+
+            curr_models[model_type] = model_name
+    elif model_type == "vae":
+        if curr_models[model_type] != model_path:
+            vae_model = [model_path] if model_path else []
+
+            opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model}
+            print("setting opts 2", opts)
+
+            try:
+                res = webui_post("/sdapi/v1/options", json=opts)
+                if res.status_code != 200:
+                    raise Exception(res.text)
+            except Exception as e:
+                raise RuntimeError(
+                    f"The engine failed to set the required options. Please check the logs in the command line window for more details."
+                )
+
+            curr_models[model_type] = model_path
+
+
+def unload_model(context, model_type, **kwargs):
+    if model_type == "vae":
+        context.model_paths[model_type] = None
+        load_model(context, model_type)
+
+
+def generate_images(
+    context: Context,
+    prompt: str = "",
+    negative_prompt: str = "",
+    seed: int = 42,
+    width: int = 512,
+    height: int = 512,
+    num_outputs: int = 1,
+    num_inference_steps: int = 25,
+    guidance_scale: float = 7.5,
+    distilled_guidance_scale: float = 3.5,
+    init_image=None,
+    init_image_mask=None,
+    control_image=None,
+    control_alpha=1.0,
+    controlnet_filter=None,
+    prompt_strength: float = 0.8,
+    preserve_init_image_color_profile=False,
+    strict_mask_border=False,
+    sampler_name: str = "euler_a",
+    scheduler_name: str = "simple",
+    hypernetwork_strength: float = 0,
+    tiling=None,
+    lora_alpha: Union[float, List[float]] = 0,
+    sampler_params={},
+    callback=None,
+    output_type="pil",
+):
+
+    task_id = str(uuid.uuid4())
+
+    sampler_name = convert_ED_sampler_names(sampler_name)
+    controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter)
+
+    cmd = {
+        "force_task_id": task_id,
+        "prompt": prompt,
+        "negative_prompt": negative_prompt,
+        "sampler_name": sampler_name,
+        "scheduler": scheduler_name,
+        "steps": num_inference_steps,
+        "seed": seed,
+        "cfg_scale": guidance_scale,
+        "distilled_cfg_scale": distilled_guidance_scale,
+        "batch_size": num_outputs,
+        "width": width,
+        "height": height,
+    }
+
+    if init_image:
+        cmd["init_images"] = [init_image]
+        cmd["denoising_strength"] = prompt_strength
+        if init_image_mask:
+            cmd["mask"] = init_image_mask
+            cmd["include_init_images"] = True
+            cmd["inpainting_fill"] = 1
+            cmd["initial_noise_multiplier"] = 1
+            cmd["inpaint_full_res"] = 1
+
+    if context.model_paths.get("lora"):
+        lora_model = context.model_paths["lora"]
+        lora_model = lora_model if isinstance(lora_model, list) else [lora_model]
+        lora_alpha = lora_alpha if isinstance(lora_alpha, list) else [lora_alpha]
+
+        for lora, alpha in zip(lora_model, lora_alpha):
+            lora = os.path.basename(lora)
+            lora = os.path.splitext(lora)[0]
+            cmd["prompt"] += f" <lora:{lora}:{alpha}>"  # A1111-style LoRA tag in the prompt
+
+    if controlnet_filter and control_image and context.model_paths.get("controlnet"):
+        controlnet_model = context.model_paths["controlnet"]
+
+        model_hash = auto1111_hash(controlnet_model)
+        controlnet_model = os.path.basename(controlnet_model)
+        controlnet_model = os.path.splitext(controlnet_model)[0]
+        print(f"setting controlnet model: {controlnet_model}")
+        controlnet_model = f"{controlnet_model} [{model_hash}]"
+
+        cmd["alwayson_scripts"] = {
+            "controlnet": {
+                "args": [
+                    {
+                        "image": control_image,
+                        "weight": control_alpha,
+                        "module": controlnet_filter,
+                        "model": controlnet_model,
+                        "resize_mode": "Crop and Resize",
+                        "threshold_a": 50,
+                        "threshold_b": 130,
+                    }
+                ]
+            }
+        }
+
+    operation_to_apply = "img2img" if init_image else "txt2img"
+
+    stream_image_progress = webui_opts.get("live_previews_enable", False)
+
+    progress_thread = Thread(
+        target=image_progress_thread, args=(task_id, callback, stream_image_progress, num_outputs, num_inference_steps)
+    )
+    progress_thread.start()
+
+    print(f"task id: {task_id}")
+    print_request(operation_to_apply, cmd)
+
+    res = webui_post(f"/sdapi/v1/{operation_to_apply}", json=cmd)
+    if res.status_code == 200:
+        res = res.json()
+    else:
+        raise Exception(
+            "The engine failed while generating this image. Please check the logs in the command-line window for more details."
+        )
+
+    import json
+
+    print(json.loads(res["info"])["infotexts"])
+
+    images = res["images"]
+    if output_type == "pil":
+        images = [base64_str_to_img(img) for img in images]
+    elif output_type == "base64":
+        images = [base64_buffer_to_base64_img(img) for img in images]
+
+    return images
+
+
+def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"):
+    """
+    * context: Context
+    * images: str or PIL.Image or list of str/PIL.Image - image to filter. if a string is passed, it needs to be a base64-encoded image
+    * filters: filter_type (string) or list of strings
+    * filter_params: dict
+
+    returns: [PIL.Image] - list of filtered images
+    """
+    images = images if isinstance(images, list) else [images]
+    filters = filters if isinstance(filters, list) else [filters]
+
+    if "nsfw_checker" in filters:
+        filters.remove("nsfw_checker")  # handled by ED directly
+
+    args = {}
+    controlnet_filters = []
+
+    print(filter_params)
+
+    for filter_name in filters:
+        params = filter_params.get(filter_name, {})
+
+        if filter_name == "gfpgan":
+            args["gfpgan_visibility"] = 1
+
+        if filter_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
+            args["upscaler_1"] = params.get("upscaler", "RealESRGAN_x4plus")
+            args["upscaling_resize"] = params.get("scale", 4)
+
+            if args["upscaler_1"] == "RealESRGAN_x4plus":
+                args["upscaler_1"] = "R-ESRGAN 4x+"
+            elif args["upscaler_1"] == "RealESRGAN_x4plus_anime_6B":
+                args["upscaler_1"] = "R-ESRGAN 4x+ Anime6B"
+
+        if filter_name == "codeformer":
+            args["codeformer_visibility"] = 1
+            args["codeformer_weight"] = params.get("codeformer_fidelity", 0.5)
+
+        if filter_name.startswith("controlnet_"):
+            filter_name = convert_ED_controlnet_filter_name(filter_name)
+            controlnet_filters.append(filter_name)
+
+    print(f"filtering {len(images)} images with {args}. {controlnet_filters=}")
+
+    if len(filters) > len(controlnet_filters):
+        filtered_images = extra_batch_images(images, input_type=input_type, **args)
+    else:
+        filtered_images = images
+
+    for filter_name in controlnet_filters:
+        filtered_images = controlnet_filter(filtered_images, module=filter_name, input_type=input_type)
+
+    return filtered_images
+
+
+def get_url():
+    return f"//{WEBUI_HOST}:{WEBUI_PORT}/?__theme=dark"
+
+
+def stop_rendering(context):
+    try:
+        res = webui_post("/sdapi/v1/interrupt")
+        if res.status_code != 200:
+            raise Exception(res.text)
+    except Exception as e:
+        print(f"Error interrupting webui: {e}")
+
+
+def refresh_models():
+    def make_refresh_call(type):
+        try:
+            webui_post(f"/sdapi/v1/refresh-{type}")
+        except:
+            pass
+
+    try:
+        for type in ("checkpoints", "vae"):
+            t = Thread(target=make_refresh_call, args=(type,))
+            t.start()
+    except Exception as e:
+        print(f"Error refreshing models: {e}")
+
+
+def list_controlnet_filters():
+    return [
+        "openpose",
+        "openpose_face",
+        "openpose_faceonly",
+        "openpose_hand",
+        "openpose_full",
+        "animal_openpose",
+        "densepose_parula (black bg & blue torso)",
+        "densepose (pruple bg & purple torso)",
+        "dw_openpose_full",
+        "mediapipe_face",
+        "instant_id_face_keypoints",
+        "InsightFace+CLIP-H (IPAdapter)",
+        "InsightFace (InstantID)",
+        "canny",
+        "mlsd",
+        "scribble_hed",
+        "scribble_hedsafe",
+        "scribble_pidinet",
+        "scribble_pidsafe",
+        "scribble_xdog",
+        "softedge_hed",
+        "softedge_hedsafe",
+        "softedge_pidinet",
+        "softedge_pidsafe",
+        "softedge_teed",
+        "normal_bae",
+        "depth_midas",
+        "normal_midas",
+        "depth_zoe",
+        "depth_leres",
+        "depth_leres++",
+        "depth_anything_v2",
+        "depth_anything",
+        "depth_hand_refiner",
+        "depth_marigold",
+        "lineart_coarse",
+        "lineart_realistic",
+        "lineart_anime",
+        "lineart_standard (from white bg & black line)",
+        "lineart_anime_denoise",
+        "reference_adain",
+        "reference_only",
+        "reference_adain+attn",
+        "tile_colorfix",
+        "tile_resample",
+        "tile_colorfix+sharp",
+        "CLIP-ViT-H (IPAdapter)",
+        "CLIP-G (Revision)",
+        "CLIP-G (Revision ignore prompt)",
+        "CLIP-ViT-bigG (IPAdapter)",
+        "InsightFace+CLIP-H (IPAdapter)",
+        "inpaint_only",
+        "inpaint_only+lama",
+        "inpaint_global_harmonious",
+        "seg_ufade20k",
+        "seg_ofade20k",
+        "seg_anime_face",
+        "seg_ofcoco",
+        "shuffle",
+        "segment",
+        "invert (from white bg & black line)",
+        "threshold",
+        "t2ia_sketch_pidi",
+        "t2ia_color_grid",
+        "recolor_intensity",
+        "recolor_luminance",
+        "blur_gaussian",
+    ]
+
+
+def controlnet_filter(images, module="none", processor_res=512, threshold_a=64, threshold_b=64, input_type="pil"):
+    if input_type == "pil":
+        images = [img_to_base64_str(x) for x in images]
+
+    payload = {
+        "controlnet_module": module,
+        "controlnet_input_images": images,
+        "controlnet_processor_res": processor_res,
+        "controlnet_threshold_a": threshold_a,
+        "controlnet_threshold_b": threshold_b,
+    }
+    res = webui_post("/controlnet/detect", json=payload)
+    res = res.json()
+    filtered_images = res["images"]
+
+    if input_type == "pil":
+        filtered_images = [base64_str_to_img(img) for img in filtered_images]
+    elif input_type == "base64":
+        filtered_images = [base64_buffer_to_base64_img(img) for img in filtered_images]
+
+    return filtered_images
+
+
+def image_progress_thread(task_id, callback, stream_image_progress, total_images, total_steps):
+    from PIL import Image
+
+    last_preview_id = -1
+
+    EMPTY_IMAGE = Image.new("RGB", (1, 1))
+
+    while True:
+        res = webui_post(
+            f"/internal/progress",
+            json={"id_task": task_id, "live_preview": stream_image_progress, "id_live_preview": last_preview_id},
+        )
+        if res.status_code == 200:
+            res = res.json()
+        else:
+            raise RuntimeError(f"Unexpected progress response. Status code: {res.status_code}. Res: {res.text}")
+
+        last_preview_id = res["id_live_preview"]
+
+        if res["progress"] is not None:
+            step_num = int(res["progress"] * total_steps)
+
+            if res["live_preview"] is not None:
+                img = res["live_preview"]
+                img = base64_str_to_img(img)
+                images = [EMPTY_IMAGE] * total_images
+                images[0] = img
+            else:
+                images = None
+
+            callback(images, step_num)
+
+        if res["completed"] == True:
+            print("Complete!")
+            break
+
+        time.sleep(0.5)
+
+
+def webui_get(uri, *args, **kwargs):
+    url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}"
+    return requests.get(url, *args, **kwargs)
+
+
+def webui_post(uri, *args, **kwargs):
+    url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}"
+    return requests.post(url, *args, **kwargs)
+
+
+def print_request(operation_to_apply, args):
+    args = deepcopy(args)
+    if "init_images" in args:
+        args["init_images"] = ["img" for _ in args["init_images"]]
+    if "mask" in args:
+        args["mask"] = "mask_img"
+
+    controlnet_args = args.get("alwayson_scripts", {}).get("controlnet", {}).get("args", [])
+    if controlnet_args:
+        controlnet_args[0]["image"] = "control_image"
+
+    print(f"operation: {operation_to_apply}, args: {args}")
+
+
+def auto1111_hash(file_path):
+    import hashlib
+
+    with open(file_path, "rb") as f:
+        f.seek(0x100000)
+        b = f.read(0x10000)
+        return hashlib.sha256(b).hexdigest()[:8]
+
+
+def extra_batch_images(
+    images,  # list of PIL images
+    name_list=None,  # list of image names
+    resize_mode=0,
+    show_extras_results=True,
+    gfpgan_visibility=0,
+    codeformer_visibility=0,
+    codeformer_weight=0,
+    upscaling_resize=2,
+    upscaling_resize_w=512,
+    upscaling_resize_h=512,
+    upscaling_crop=True,
+    upscaler_1="None",
+    upscaler_2="None",
+    extras_upscaler_2_visibility=0,
+    upscale_first=False,
+    use_async=False,
+    input_type="pil",
+):
+    if name_list is not None:
+        if len(name_list) != len(images):
+            raise RuntimeError("len(images) != len(name_list)")
+    else:
+        name_list = [f"image{i + 1:05}" for i in range(len(images))]
+
+    if input_type == "pil":
+        images = [img_to_base64_str(x) for x in images]
+
+    image_list = []
+    for name, image in zip(name_list, images):
+        image_list.append({"data": image, "name": name})
+
+    payload = {
+        "resize_mode": resize_mode,
+        "show_extras_results": show_extras_results,
+        "gfpgan_visibility": gfpgan_visibility,
+        "codeformer_visibility": codeformer_visibility,
+        "codeformer_weight": codeformer_weight,
+        "upscaling_resize": upscaling_resize,
+        "upscaling_resize_w": upscaling_resize_w,
+        "upscaling_resize_h": upscaling_resize_h,
+        "upscaling_crop": upscaling_crop,
+        "upscaler_1": upscaler_1,
+        "upscaler_2": upscaler_2,
+        "extras_upscaler_2_visibility": extras_upscaler_2_visibility,
+        "upscale_first": upscale_first,
+        "imageList": image_list,
+    }
+
+    res = webui_post("/sdapi/v1/extra-batch-images", json=payload)
+    if res.status_code == 200:
+        res = res.json()
+    else:
+        raise Exception(
+            "The engine failed while filtering this image. Please check the logs in the command-line window for more details."
+        )
+
+    images = res["images"]
+
+    if input_type == "pil":
+        images = [base64_str_to_img(img) for img in images]
+    elif input_type == "base64":
+        images = [base64_buffer_to_base64_img(img) for img in images]
+
+    return images
+
+
+def base64_buffer_to_base64_img(img):
+    output_format = webui_opts.get("samples_format", "jpeg")
+    mime_type = f"image/{output_format.lower()}"
+    return f"data:{mime_type};base64," + img
+
+
+def convert_ED_sampler_names(sampler_name):
+    name_mapping = {
+        "dpmpp_2m": "DPM++ 2M",
+        "dpmpp_sde": "DPM++ SDE",
+        "dpmpp_2m_sde": "DPM++ 2M SDE",
+        "dpmpp_2m_sde_heun": "DPM++ 2M SDE Heun",
+        "dpmpp_2s_a": "DPM++ 2S a",
+        "dpmpp_3m_sde": "DPM++ 3M SDE",
+        "euler_a": "Euler a",
+        "euler": "Euler",
+        "lms": "LMS",
+        "heun": "Heun",
+        "dpm2": "DPM2",
+        "dpm2_a": "DPM2 a",
+        "dpm_fast": "DPM fast",
+        "dpm_adaptive": "DPM adaptive",
+        "restart": "Restart",
+        "heun_pp2": "HeunPP2",
+        "ipndm": "IPNDM",
+        "ipndm_v": "IPNDM_V",
+        "deis": "DEIS",
+        "ddim": "DDIM",
+        "ddim_cfgpp": "DDIM CFG++",
+        "plms": "PLMS",
+        "unipc": "UniPC",
+        "lcm": "LCM",
+        "ddpm": "DDPM",
+        "forge_flux_realistic": "[Forge] Flux Realistic",
+        "forge_flux_realistic_slow": "[Forge] Flux Realistic (Slow)",
+        # deprecated samplers in 3.5
+        "dpm_solver_stability": None,
+        "unipc_snr": None,
+        "unipc_tu": None,
+        "unipc_snr_2": None,
+        "unipc_tu_2": None,
+        "unipc_tq": None,
+    }
+    return name_mapping.get(sampler_name)
+
+
+def convert_ED_controlnet_filter_name(filter):
+    if filter is None:
+        return None
+
+    def cn(n):
+        if n.startswith("controlnet_"):
+            return n[len("controlnet_") :]
+        return n
+
+    mapping = {
+        "controlnet_scribble_hedsafe": None,
+        "controlnet_scribble_pidsafe": None,
+        "controlnet_softedge_pidsafe": "controlnet_softedge_pidisafe",
+        "controlnet_normal_bae": "controlnet_normalbae",
+        "controlnet_segment": None,
+    }
+    if isinstance(filter, list):
+        return [cn(mapping.get(f, f)) for f in filter]
+    return cn(mapping.get(filter, filter))
diff --git a/ui/easydiffusion/easydb/schemas.py b/ui/easydiffusion/easydb/schemas.py
index 68bc04e2..c6b7337f 100644
--- a/ui/easydiffusion/easydb/schemas.py
+++ b/ui/easydiffusion/easydb/schemas.py
@@ -33,4 +33,3 @@ class Bucket(BucketBase):
 
     class Config:
         orm_mode = True
-
diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py
index 26f05242..e904a315 100644
--- a/ui/easydiffusion/model_manager.py
+++ b/ui/easydiffusion/model_manager.py
@@ -8,8 +8,7 @@ from easydiffusion import app
 from easydiffusion.types import ModelsData
 from easydiffusion.utils import log
 from sdkit import Context
-from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
-from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
+from sdkit.models import scan_model, download_model, get_model_info_from_db
 from sdkit.utils import hash_file_quick
 from sdkit.models.model_loader.embeddings import get_embedding_token
 
@@ -25,15 +24,15 @@ KNOWN_MODEL_TYPES = [
     "controlnet",
 ]
 MODEL_EXTENSIONS = {
-    "stable-diffusion": [".ckpt", ".safetensors"],
-    "vae": [".vae.pt", ".ckpt", ".safetensors"],
-    "hypernetwork": [".pt", ".safetensors"],
+    "stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"],
+    "vae": [".vae.pt", ".ckpt", ".safetensors", ".sft"],
+    "hypernetwork": [".pt", ".safetensors", ".sft"],
     "gfpgan": [".pth"],
     "realesrgan": [".pth"],
-    "lora": [".ckpt", ".safetensors", ".pt"],
+    "lora": [".ckpt", ".safetensors", ".sft", ".pt"],
     "codeformer": [".pth"],
-    "embeddings": [".pt", ".bin", ".safetensors"],
-    "controlnet": [".pth", ".safetensors"],
+    "embeddings": [".pt", ".bin", ".safetensors", ".sft"],
+    "controlnet": [".pth", ".safetensors", ".sft"],
 }
 DEFAULT_MODELS = {
     "stable-diffusion": [
@@ -51,6 +50,16 @@
     ],
 }
 MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
+ALTERNATE_FOLDER_NAMES = {  # for WebUI compatibility
+    "stable-diffusion": "Stable-diffusion",
+    "vae": "VAE",
+    "hypernetwork": "hypernetworks",
+    "codeformer": "Codeformer",
+    "gfpgan": "GFPGAN",
+    "realesrgan": "RealESRGAN",
+    "lora": "Lora",
+    "controlnet": "ControlNet",
+}
 
 known_models = {}
 
@@ -63,6 +72,7 @@ def init():
 
 def load_default_models(context: Context):
     from easydiffusion import runtime
+    from easydiffusion.backend_manager import backend
 
     runtime.set_vram_optimizations(context)
 
@@ -70,7 +80,7 @@ def load_default_models(context: Context):
     for model_type in MODELS_TO_LOAD_ON_START:
         context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False)
         try:
-            load_model(
+            backend.load_model(
                 context,
                 model_type,
                 scan_model=context.model_paths[model_type] != None
@@ -92,9 +102,11 @@
 
 
 def unload_all(context: Context):
+    from easydiffusion.backend_manager import backend
+
     for model_type in KNOWN_MODEL_TYPES:
-        unload_model(context, model_type)
-        if model_type in context.model_load_errors:
+        backend.unload_model(context, model_type)
+        if hasattr(context, "model_load_errors") and model_type in context.model_load_errors:
             del context.model_load_errors[model_type]
 
 
@@ -119,33 +131,33 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
     default_models = DEFAULT_MODELS.get(model_type, [])
     config = app.getConfig()
 
-    model_dir = os.path.join(app.MODELS_DIR, model_type)
     if not model_name:  # When None try user configured model.
         # config = getConfig()
         if "model" in config and model_type in config["model"]:
            model_name = config["model"][model_type]
 
-    if model_name:
-        # Check models directory
-        model_path = os.path.join(model_dir, model_name)
-        if os.path.exists(model_path):
-            return model_path
-        for model_extension in model_extensions:
-            if os.path.exists(model_path + model_extension):
-                return model_path + model_extension
-            if os.path.exists(model_name + model_extension):
-                return os.path.abspath(model_name + model_extension)
+    for model_dir in get_model_dirs(model_type):
+        if model_name:
+            # Check models directory
+            model_path = os.path.join(model_dir, model_name)
+            if os.path.exists(model_path):
+                return model_path
+            for model_extension in model_extensions:
+                if os.path.exists(model_path + model_extension):
+                    return model_path + model_extension
+                if os.path.exists(model_name + model_extension):
+                    return os.path.abspath(model_name + model_extension)
 
-    # Can't find requested model, check the default paths.
-    if model_type == "stable-diffusion" and not fail_if_not_found:
-        for default_model in default_models:
-            default_model_path = os.path.join(model_dir, default_model["file_name"])
-            if os.path.exists(default_model_path):
-                if model_name is not None:
-                    log.warn(
-                        f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
-                    )
-                return default_model_path
+        # Can't find requested model, check the default paths.
+        if model_type == "stable-diffusion" and not fail_if_not_found:
+            for default_model in default_models:
+                default_model_path = os.path.join(model_dir, default_model["file_name"])
+                if os.path.exists(default_model_path):
+                    if model_name is not None:
+                        log.warn(
+                            f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
+                        )
+                    return default_model_path
 
     if model_name and fail_if_not_found:
         raise FileNotFoundError(
@@ -154,6 +166,8 @@
 
 
 def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
+    from easydiffusion.backend_manager import backend
+
     models_to_reload = {
         model_type: path
         for model_type, path in models_data.model_paths.items()
@@ -175,7 +189,7 @@
 
     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
-        action_fn = unload_model if context.model_paths[model_type] is None else load_model
+        action_fn = backend.unload_model if context.model_paths[model_type] is None else backend.load_model
         extra_params = models_data.model_params.get(model_type, {})
         try:
             action_fn(context, model_type, scan_model=False, **extra_params)  # we've scanned them already
@@ -183,14 +197,27 @@
                 del context.model_load_errors[model_type]
         except Exception as e:
             log.exception(e)
-            if action_fn == load_model:
+            if action_fn == backend.load_model:
                 context.model_load_errors[model_type] = str(e)  # storing the entire Exception can lead to memory leaks
 
 
 def resolve_model_paths(models_data: ModelsData):
+    from easydiffusion.backend_manager import backend
+
+    cn_filters = backend.list_controlnet_filters()
+
     model_paths = models_data.model_paths
+    skip_models = cn_filters + [
+        "latent_upscaler",
+        "nsfw_checker",
+        "esrgan_4x",
+        "lanczos",
+        "nearest",
+        "scunet",
+        "swinir",
+    ]
+
     for model_type in model_paths:
-        skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"]
         if model_type in skip_models:  # doesn't use model paths
             continue
         if model_type == "codeformer" and model_paths[model_type]:
@@ -225,7 +252,8 @@ def download_default_models_if_necessary():
 
 def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
-    model_path = os.path.join(app.MODELS_DIR, model_type, file_name)
+    model_dir = get_model_dirs(model_type)[0]
+    model_path = os.path.join(model_dir, file_name)
     expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
 
     other_models_exist = any_model_exists(model_type) and skip_if_others_exist
@@ -245,21 +273,23 @@ def migrate_legacy_model_location():
             file_name = model["file_name"]
             legacy_path = os.path.join(app.SD_DIR, file_name)
             if os.path.exists(legacy_path):
-                shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name))
+                model_dir = get_model_dirs(model_type)[0]
+                shutil.move(legacy_path, os.path.join(model_dir, file_name))
 
 
 def any_model_exists(model_type: str) -> bool:
     extensions = MODEL_EXTENSIONS.get(model_type, [])
-    for ext in extensions:
-        if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)):
-            return True
+    for model_dir in get_model_dirs(model_type):
+        for ext in extensions:
+            if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
+                return True
 
     return False
 
 
 def make_model_folders():
     for
model_type in KNOWN_MODEL_TYPES: - model_dir_path = os.path.join(app.MODELS_DIR, model_type) + model_dir_path = get_model_dirs(model_type)[0] try: os.makedirs(model_dir_path, exist_ok=True) @@ -282,9 +312,11 @@ def make_model_folders(): help_file_name = f"Place your {model_type} model files here.txt" help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' - - with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: - f.write(help_file_contents) + try: + with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: + f.write(help_file_contents) + except Exception as e: + log.exception(e) def is_malicious_model(file_path): @@ -320,6 +352,10 @@ def is_malicious_model(file_path): def getModels(scan_for_malicious: bool = True): + from easydiffusion.backend_manager import backend + + backend.refresh_models() + models = { "options": { "stable-diffusion": [], @@ -329,19 +365,19 @@ def getModels(scan_for_malicious: bool = True): "codeformer": [{"codeformer": "CodeFormer"}], "embeddings": [], "controlnet": [ - {"control_v11p_sd15_canny": "Canny (*)"}, - {"control_v11p_sd15_openpose": "OpenPose (*)"}, - {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, - {"control_v11f1p_sd15_depth": "Depth (*)"}, - {"control_v11p_sd15_scribble": "Scribble"}, - {"control_v11p_sd15_softedge": "Soft Edge"}, - {"control_v11p_sd15_inpaint": "Inpaint"}, - {"control_v11p_sd15_lineart": "Line Art"}, - {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, - {"control_v11p_sd15_mlsd": "Straight Lines"}, - {"control_v11p_sd15_seg": "Segment"}, - {"control_v11e_sd15_shuffle": "Shuffle"}, - {"control_v11f1e_sd15_tile": "Tile"}, + # {"control_v11p_sd15_canny": "Canny (*)"}, + # {"control_v11p_sd15_openpose": "OpenPose (*)"}, + # {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, + # {"control_v11f1p_sd15_depth": "Depth (*)"}, + # {"control_v11p_sd15_scribble": "Scribble"}, + # {"control_v11p_sd15_softedge": "Soft Edge"}, + # {"control_v11p_sd15_inpaint": "Inpaint"}, + # {"control_v11p_sd15_lineart": "Line Art"}, + # {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, + # {"control_v11p_sd15_mlsd": "Straight Lines"}, + # {"control_v11p_sd15_seg": "Segment"}, + # {"control_v11e_sd15_shuffle": "Shuffle"}, + # {"control_v11f1e_sd15_tile": "Tile"}, ], }, } @@ -356,6 +392,9 @@ def getModels(scan_for_malicious: bool = True): tree = list(default_entries) + if not os.path.exists(directory): + return tree + for entry in sorted( os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), @@ -378,6 +417,8 @@ def getModels(scan_for_malicious: bool = True): model_id = entry.name[: -len(matching_suffix)] if callable(nameFilter): model_id = nameFilter(model_id) + if model_id is None: + continue model_exists = False for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models @@ -398,17 +439,18 @@ def getModels(scan_for_malicious: bool = True): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) - models_dir = os.path.join(app.MODELS_DIR, model_type) - if not os.path.exists(models_dir): - os.makedirs(models_dir) + models_dirs = get_model_dirs(model_type) + if not os.path.exists(models_dirs[0]): + os.makedirs(models_dirs[0]) - try: - default_tree = models["options"].get(model_type, []) - models["options"][model_type] = scan_directory( - models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter - ) - except MaliciousModelException 
as e: - models["scan-error"] = str(e) + for model_dir in models_dirs: + try: + default_tree = models["options"].get(model_type, []) + models["options"][model_type] = scan_directory( + model_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter + ) + except MaliciousModelException as e: + models["scan-error"] = str(e) if scan_for_malicious: log.info(f"[green]Scanning all model folders for models...[/]") @@ -416,7 +458,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="stable-diffusion") listModels(model_type="vae") listModels(model_type="hypernetwork") - listModels(model_type="gfpgan") + listModels(model_type="gfpgan", nameFilter=lambda x: (x if "gfpgan" in x.lower() else None)) listModels(model_type="lora") listModels(model_type="embeddings", nameFilter=get_embedding_token) listModels(model_type="controlnet") @@ -425,3 +467,20 @@ def getModels(scan_for_malicious: bool = True): log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") return models + + +def get_model_dirs(model_type: str, base_dir=None): + "Returns the possible model directory paths for the given model type. Mainly used for WebUI compatibility" + + if base_dir is None: + base_dir = app.MODELS_DIR + + dirs = [os.path.join(base_dir, model_type)] + + if model_type in ALTERNATE_FOLDER_NAMES: + alt_dir = ALTERNATE_FOLDER_NAMES[model_type] + alt_dir = os.path.join(base_dir, alt_dir) + if os.path.exists(alt_dir) and os.path.isdir(alt_dir): + dirs.append(alt_dir) + + return dirs diff --git a/ui/easydiffusion/package_manager.py b/ui/easydiffusion/package_manager.py index c28a58a1..9df7f1fd 100644 --- a/ui/easydiffusion/package_manager.py +++ b/ui/easydiffusion/package_manager.py @@ -3,8 +3,6 @@ import os import platform from importlib.metadata import version as pkg_version -from sdkit.utils import log - from easydiffusion import app # future home of scripts/check_modules.py @@ -50,6 +48,8 @@ def is_installed(module_name) -> bool: def install(module_name): + from easydiffusion.utils import log + if is_installed(module_name): log.info(f"{module_name} has already been installed!") return @@ -79,6 +79,8 @@ def install(module_name): def uninstall(module_name): + from easydiffusion.utils import log + if not is_installed(module_name): log.info(f"{module_name} hasn't been installed!") return diff --git a/ui/easydiffusion/runtime.py b/ui/easydiffusion/runtime.py index 78d90f60..d85839c0 100644 --- a/ui/easydiffusion/runtime.py +++ b/ui/easydiffusion/runtime.py @@ -1,4 +1,5 @@ """ +(OUTDATED DOC) A runtime that runs on a specific device (in a thread). It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context. @@ -6,45 +7,38 @@ It can run various tasks like image generation, image filtering, model merge etc This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function. """ -from easydiffusion import device_manager -from easydiffusion.utils import log -from sdkit import Context -from sdkit.utils import get_device_usage - -context = Context() # thread-local -""" -runtime data (bound locally to this thread), for e.g. 
device, references to loaded models, optimization flags etc -""" +context = None def init(device): """ Initializes the fields that will be bound to this runtime's context, and sets the current torch device """ + + global context + + from easydiffusion import device_manager + from easydiffusion.backend_manager import backend + from easydiffusion.app import getConfig + + context = backend.create_context() + context.stop_processing = False context.temp_images = {} context.partial_x_samples = None context.model_load_errors = {} context.enable_codeformer = True - from easydiffusion import app - - app_config = app.getConfig() - context.test_diffusers = app_config.get("use_v3_engine", True) - - log.info("Device usage during initialization:") - get_device_usage(device, log_info=True, process_usage_only=False) - device_manager.device_init(context, device) -def set_vram_optimizations(context: Context): - from easydiffusion import app +def set_vram_optimizations(context): + from easydiffusion.app import getConfig - config = app.getConfig() + config = getConfig() vram_usage_level = config.get("vram_usage_level", "balanced") - if vram_usage_level != context.vram_usage_level: + if hasattr(context, "vram_usage_level") and vram_usage_level != context.vram_usage_level: context.vram_usage_level = vram_usage_level return True diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index a251ede6..7a82680e 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -2,6 +2,7 @@ Notes: async endpoints always run on the main thread. Without they run on the thread pool. """ + import datetime import mimetypes import os @@ -20,6 +21,7 @@ from easydiffusion.types import ( OutputFormatData, SaveToDiskData, convert_legacy_render_req_to_new, + convert_legacy_controlnet_filter_name, ) from easydiffusion.utils import log from fastapi import FastAPI, HTTPException @@ -67,7 +69,9 @@ class SetAppConfigRequest(BaseModel, extra=Extra.allow): listen_to_network: bool = None listen_port: int = None use_v3_engine: bool = True + backend: str = "ed_diffusers" models_dir: str = None + vram_usage_level: str = "balanced" def init(): @@ -155,6 +159,12 @@ def init(): def shutdown_event(): # Signal render thread to close on shutdown task_manager.current_state_error = SystemExit("Application shutting down.") + @server_api.on_event("startup") + def start_event(): + from easydiffusion.app import open_browser + + open_browser() + # API implementations def set_app_config_internal(req: SetAppConfigRequest): @@ -176,8 +186,10 @@ def set_app_config_internal(req: SetAppConfigRequest): config["net"] = {} config["net"]["listen_port"] = int(req.listen_port) - config["use_v3_engine"] = req.use_v3_engine + config["use_v3_engine"] = req.backend == "ed_diffusers" + config["backend"] = req.backend config["models_dir"] = req.models_dir + config["vram_usage_level"] = req.vram_usage_level for property, property_value in req.dict().items(): if property_value is not None and property not in req.__fields__ and property not in PROTECTED_CONFIG_KEYS: @@ -216,6 +228,8 @@ def read_web_data_internal(key: str = None, **kwargs): return JSONResponse(config, headers=NOCACHE_HEADERS) elif key == "system_info": + from easydiffusion.backend_manager import backend + config = app.getConfig() output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) @@ -226,6 +240,7 @@ def read_web_data_internal(key: str = None, **kwargs): "default_output_dir": output_dir, "enforce_output_dir": ("force_save_path" in 
config), "enforce_output_metadata": ("force_save_metadata" in config), + "backend_url": backend.get_url(), } system_info["devices"]["config"] = config.get("render_devices", "auto") return JSONResponse(system_info, headers=NOCACHE_HEADERS) @@ -309,6 +324,15 @@ def filter_internal(req: dict): output_format: OutputFormatData = OutputFormatData.parse_obj(req) save_data: SaveToDiskData = SaveToDiskData.parse_obj(req) + filter_req.filter = convert_legacy_controlnet_filter_name(filter_req.filter) + + for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if models_data.model_paths.get(model_name): + if model_name not in filter_req.filter_params: + filter_req.filter_params[model_name] = {} + + filter_req.filter_params[model_name]["upscaler"] = models_data.model_paths[model_name] + # enqueue the task task = FilterTask(filter_req, task_data, models_data, output_format, save_data) return enqueue_task(task) @@ -342,15 +366,13 @@ def model_merge_internal(req: dict): mergeReq: MergeRequest = MergeRequest.parse_obj(req) + sd_model_dir = model_manager.get_model_dir("stable-diffusion")[0] + merge_models( model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), mergeReq.ratio, - os.path.join( - app.MODELS_DIR, - "stable-diffusion", - filename_regex.sub("_", mergeReq.out_path), - ), + os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)), mergeReq.use_fp16, ) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 699b4494..38cba0b6 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -4,6 +4,7 @@ Notes: Use weak_thread_data to store all other data using weak keys. This will allow for garbage collection after the thread dies. """ + import json import traceback @@ -19,7 +20,6 @@ import torch from easydiffusion import device_manager from easydiffusion.tasks import Task from easydiffusion.utils import log -from sdkit.utils import gc THREAD_NAME_PREFIX = "" ERR_LOCK_FAILED = " failed to acquire lock within timeout." 
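(For reference: the `filter_internal` hunk above threads the selected upscaler's model path into that filter's params. Below is a minimal standalone sketch of that mapping; `inject_upscaler_params` is a hypothetical helper, and plain dicts stand in for the real request objects.)

```python
# Sketch of the upscaler-param injection done in filter_internal above.
# `model_paths` / `filter_params` are plain-dict stand-ins (illustrative only).
UPSCALER_FILTERS = ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir")

def inject_upscaler_params(model_paths: dict, filter_params: dict) -> dict:
    for name in UPSCALER_FILTERS:
        if model_paths.get(name):
            # ensure a params dict exists, then record which upscaler file to use
            filter_params.setdefault(name, {})["upscaler"] = model_paths[name]
    return filter_params

# e.g. a request that selected a SwinIR model gains {"upscaler": "SwinIR_4x"}
print(inject_upscaler_params({"swinir": "SwinIR_4x"}, {}))
# -> {'swinir': {'upscaler': 'SwinIR_4x'}}
```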
@@ -233,6 +233,8 @@ def thread_render(device): global current_state, current_state_error from easydiffusion import model_manager, runtime + from easydiffusion.backend_manager import backend + from requests import ConnectionError try: runtime.init(device) @@ -244,8 +246,17 @@ def thread_render(device): } current_state = ServerStates.LoadingModel - model_manager.load_default_models(runtime.context) + while True: + try: + if backend.ping(timeout=1): + break + + time.sleep(1) + except (TimeoutError, ConnectionError): + time.sleep(1) + + model_manager.load_default_models(runtime.context) current_state = ServerStates.Online except Exception as e: log.error(traceback.format_exc()) @@ -291,7 +302,6 @@ def thread_render(device): task.buffer_queue.put(json.dumps(task.response)) log.error(traceback.format_exc()) finally: - gc(runtime.context) task.lock.release() keep_task_alive(task) diff --git a/ui/easydiffusion/tasks/filter_images.py b/ui/easydiffusion/tasks/filter_images.py index 7d3d1326..67572a8a 100644 --- a/ui/easydiffusion/tasks/filter_images.py +++ b/ui/easydiffusion/tasks/filter_images.py @@ -5,9 +5,7 @@ import time from numpy import base_repr -from sdkit.filter import apply_filters -from sdkit.models import load_model -from sdkit.utils import img_to_base64_str, get_image, log, save_images +from sdkit.utils import img_to_base64_str, log, save_images, base64_str_to_img from easydiffusion import model_manager, runtime from easydiffusion.types import ( @@ -19,6 +17,7 @@ from easydiffusion.types import ( TaskData, GenerateImageRequest, ) +from easydiffusion.utils import filter_nsfw from easydiffusion.utils.save_utils import format_folder_name from .task import Task @@ -47,7 +46,9 @@ class FilterTask(Task): # convert to multi-filter format, if necessary if isinstance(req.filter, str): - req.filter_params = {req.filter: req.filter_params} + if req.filter not in req.filter_params: + req.filter_params = {req.filter: req.filter_params} + req.filter = [req.filter] if not isinstance(req.image, list): @@ -57,6 +58,7 @@ class FilterTask(Task): "Runs the image filtering task on the assigned thread" from easydiffusion import app + from easydiffusion.backend_manager import backend context = runtime.context @@ -66,15 +68,24 @@ class FilterTask(Task): print_task_info(self.request, self.models_data, self.output_format, self.save_data) - if isinstance(self.request.image, list): - images = [get_image(img) for img in self.request.image] - else: - images = get_image(self.request.image) - - images = filter_images(context, images, self.request.filter, self.request.filter_params) + has_nsfw_filter = "nsfw_filter" in self.request.filter output_format = self.output_format + backend.set_options( + context, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + ) + + images = backend.filter_images( + context, self.request.image, self.request.filter, self.request.filter_params, input_type="base64" + ) + + if has_nsfw_filter: + images = filter_nsfw(images) + if self.save_data.save_to_disk_path is not None: app_config = app.getConfig() folder_format = app_config.get("folder_format", "$id") @@ -85,8 +96,9 @@ class FilterTask(Task): save_dir_path = os.path.join( self.save_data.save_to_disk_path, format_folder_name(folder_format, dummy_req, self.task_data) ) + images_pil = [base64_str_to_img(img) for img in images] save_images( - images, + images_pil, save_dir_path, file_name=img_id, output_format=output_format.output_format, @@ -94,13 
+106,6 @@ class FilterTask(Task): output_lossless=output_format.output_lossless, ) - images = [ - img_to_base64_str( - img, output_format.output_format, output_format.output_quality, output_format.output_lossless - ) - for img in images - ] - res = FilterImageResponse(self.request, self.models_data, images=images) res = res.json() self.buffer_queue.put(json.dumps(res)) @@ -110,46 +115,6 @@ class FilterTask(Task): self.response = res -def filter_images(context, images, filters, filter_params={}): - filters = filters if isinstance(filters, list) else [filters] - - for filter_name in filters: - params = filter_params.get(filter_name, {}) - - previous_state = before_filter(context, filter_name, params) - - try: - images = apply_filters(context, filter_name, images, **params) - finally: - after_filter(context, filter_name, params, previous_state) - - return images - - -def before_filter(context, filter_name, filter_params): - if filter_name == "codeformer": - from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use - - default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] - prev_realesrgan_path = None - - upscale_faces = filter_params.get("upscale_faces", False) - if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: - prev_realesrgan_path = context.model_paths.get("realesrgan") - context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") - load_model(context, "realesrgan") - - return prev_realesrgan_path - - -def after_filter(context, filter_name, filter_params, previous_state): - if filter_name == "codeformer": - prev_realesrgan_path = previous_state - if prev_realesrgan_path: - context.model_paths["realesrgan"] = prev_realesrgan_path - load_model(context, "realesrgan") - - def print_task_info( req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData ): diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index 7d8fbc5e..2a46e6e6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -2,26 +2,23 @@ import json import pprint import queue import time +from PIL import Image from easydiffusion import model_manager, runtime from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData, SaveToDiskData from easydiffusion.types import Image as ResponseImage -from easydiffusion.types import GenerateImageResponse, RenderTaskData, UserInitiatedStop -from easydiffusion.utils import get_printable_request, log, save_images_to_disk -from sdkit.generate import generate_images +from easydiffusion.types import GenerateImageResponse, RenderTaskData +from easydiffusion.utils import get_printable_request, log, save_images_to_disk, filter_nsfw from sdkit.utils import ( - diffusers_latent_samples_to_images, - gc, img_to_base64_str, + base64_str_to_img, img_to_buffer, - latent_samples_to_images, resize_img, get_image, log, ) from .task import Task -from .filter_images import filter_images class RenderTask(Task): @@ -51,15 +48,13 @@ class RenderTask(Task): "Runs the image generation task on the assigned thread" from easydiffusion import task_manager, app + from easydiffusion.backend_manager import backend context = runtime.context config = app.getConfig() if config.get("block_nsfw", False): # override if set on the server self.task_data.block_nsfw = True - if "nsfw_checker" not in self.task_data.filters: - self.task_data.filters.append("nsfw_checker") - 
self.models_data.model_paths["nsfw_checker"] = "nsfw_checker" def step_callback(): task_manager.keep_task_alive(self) @@ -68,7 +63,7 @@ class RenderTask(Task): if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance( self.error, StopAsyncIteration ): - context.stop_processing = True + backend.stop_rendering(context) if isinstance(task_manager.current_state_error, StopAsyncIteration): self.error = task_manager.current_state_error task_manager.current_state_error = None @@ -78,11 +73,7 @@ class RenderTask(Task): model_manager.resolve_model_paths(self.models_data) models_to_force_reload = [] - if ( - runtime.set_vram_optimizations(context) - or self.has_param_changed(context, "clip_skip") - or self.trt_needs_reload(context) - ): + if runtime.set_vram_optimizations(context) or self.has_param_changed(context, "clip_skip"): models_to_force_reload.append("stable-diffusion") model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload) @@ -99,10 +90,11 @@ class RenderTask(Task): self.buffer_queue, self.temp_images, step_callback, + self, ) def has_param_changed(self, context, param_name): - if not context.test_diffusers: + if not getattr(context, "test_diffusers", False): return False if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: return True @@ -111,29 +103,6 @@ class RenderTask(Task): new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) return model["params"].get(param_name) != new_val - def trt_needs_reload(self, context): - if not context.test_diffusers: - return False - if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: - return True - - model = context.models["stable-diffusion"] - - # curr_convert_to_trt = model["params"].get("convert_to_tensorrt") - new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False) - - pipe = model["default"] - is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr( - pipe.unet, "_allocate_trt_buffers_backup" - ) - if new_convert_to_trt and not is_trt_loaded: - return True - - curr_build_config = model["params"].get("trt_build_config") - new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {}) - - return new_convert_to_trt and curr_build_config != new_build_config - def make_images( context, @@ -145,12 +114,21 @@ def make_images( data_queue: queue.Queue, task_temp_images: list, step_callback, + task, ): - context.stop_processing = False print_task_info(req, task_data, models_data, output_format, save_data) images, seeds = make_images_internal( - context, req, task_data, models_data, output_format, save_data, data_queue, task_temp_images, step_callback + context, + req, + task_data, + models_data, + output_format, + save_data, + data_queue, + task_temp_images, + step_callback, + task, ) res = GenerateImageResponse( @@ -170,7 +148,9 @@ def print_task_info( output_format: OutputFormatData, save_data: SaveToDiskData, ): - req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[") + req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace( + "[", "\[" + ) task_str = pprint.pformat(task_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[") output_format = 
pprint.pformat(output_format.dict()).replace("[", "\[") @@ -178,7 +158,7 @@ def print_task_info( log.info(f"request: {req_str}") log.info(f"task data: {task_str}") - # log.info(f"models data: {models_data}") + log.info(f"models data: {models_data}") log.info(f"output format: {output_format}") log.info(f"save data: {save_data}") @@ -193,26 +173,41 @@ def make_images_internal( data_queue: queue.Queue, task_temp_images: list, step_callback, + task, ): - images, user_stopped = generate_images_internal( + from easydiffusion.backend_manager import backend + + # prep the nsfw_filter + if task_data.block_nsfw: + filter_nsfw([Image.new("RGB", (1, 1))]) # hack - ensures that the model is available + + images = generate_images_internal( context, req, task_data, models_data, + output_format, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval, ) - - gc(context) + user_stopped = isinstance(task.error, StopAsyncIteration) filters, filter_params = task_data.filters, task_data.filter_params - filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images + if len(filters) > 0 and not user_stopped: + filtered_images = backend.filter_images(context, images, filters, filter_params, input_type="base64") + else: + filtered_images = images + + if task_data.block_nsfw: + filtered_images = filter_nsfw(filtered_images) if save_data.save_to_disk_path is not None: - save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data) + images_pil = [base64_str_to_img(img) for img in images] + filtered_images_pil = [base64_str_to_img(img) for img in filtered_images] + save_images_to_disk(images_pil, filtered_images_pil, req, task_data, models_data, output_format, save_data) seeds = [*range(req.seed, req.seed + len(images))] if task_data.show_only_filtered_image or filtered_images is images: @@ -226,97 +221,43 @@ def generate_images_internal( req: GenerateImageRequest, task_data: RenderTaskData, models_data: ModelsData, + output_format: OutputFormatData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool, stream_image_progress_interval: int, ): - context.temp_images.clear() + from easydiffusion.backend_manager import backend - callback = make_step_callback( + callback = make_step_callback(context, req, task_data, data_queue, task_temp_images, step_callback) + + req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8 + + if req.control_image and task_data.control_filter_to_apply: + req.controlnet_filter = task_data.control_filter_to_apply + + if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: + req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 + + backend.set_options( context, - req, - task_data, - data_queue, - task_temp_images, - step_callback, - stream_image_progress, - stream_image_progress_interval, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + vae_tiling=task_data.enable_vae_tiling, + stream_image_progress=stream_image_progress, + stream_image_progress_interval=stream_image_progress_interval, + clip_skip=2 if task_data.clip_skip else 1, ) - try: - if req.init_image is not None and not context.test_diffusers: - req.sampler_name = "ddim" + images = backend.generate_images(context, callback=callback, output_type="base64", **req.dict()) 
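(The two request guards added just above are easy to miss in diff form. Restated as a standalone sketch; the helper names are hypothetical, the logic mirrors the added lines.)

```python
# Standalone restatement of the request normalization added above (illustrative).
def clamp_to_multiple_of_8(width: int, height: int) -> tuple:
    # latent-space models expect dimensions divisible by 8
    return width - width % 8, height - height % 8

def effective_prompt_strength(steps: int, strength: float) -> float:
    # img2img runs int(steps * strength) denoising steps; make sure at
    # least one step survives the rounding
    if steps > 0 and int(steps * strength) == 0:
        return 1 / steps
    return strength

assert clamp_to_multiple_of_8(513, 770) == (512, 768)
assert int(20 * effective_prompt_strength(20, 0.04)) >= 1
```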
- req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8 - - if req.control_image and task_data.control_filter_to_apply: - req.control_image = get_image(req.control_image) - req.control_image = resize_img(req.control_image.convert("RGB"), req.width, req.height, clamp_to_8=True) - req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0] - - if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: - req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 - - if context.test_diffusers: - pipe = context.models["stable-diffusion"]["default"] - if hasattr(pipe.unet, "_allocate_trt_buffers_backup"): - setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup) - delattr(pipe.unet, "_allocate_trt_buffers_backup") - - if hasattr(pipe.unet, "_allocate_trt_buffers"): - convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False) - if convert_to_trt: - pipe.unet.forward = pipe.unet._trt_forward - # pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward - log.info(f"Setting unet.forward to TensorRT") - else: - log.info(f"Not using TensorRT for unet.forward") - pipe.unet.forward = pipe.unet._non_trt_forward - # pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward - setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers) - delattr(pipe.unet, "_allocate_trt_buffers") - - if task_data.enable_vae_tiling: - if hasattr(pipe, "enable_vae_tiling"): - pipe.enable_vae_tiling() - else: - if hasattr(pipe, "disable_vae_tiling"): - pipe.disable_vae_tiling() - - images = generate_images(context, callback=callback, **req.dict()) - user_stopped = False - except UserInitiatedStop: - images = [] - user_stopped = True - if context.partial_x_samples is not None: - if context.test_diffusers: - images = diffusers_latent_samples_to_images(context, context.partial_x_samples) - else: - images = latent_samples_to_images(context, context.partial_x_samples) - finally: - if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: - if not context.test_diffusers: - del context.partial_x_samples - context.partial_x_samples = None - - return images, user_stopped + return images def construct_response(images: list, seeds: list, output_format: OutputFormatData): - return [ - ResponseImage( - data=img_to_base64_str( - img, - output_format.output_format, - output_format.output_quality, - output_format.output_lossless, - ), - seed=seed, - ) - for img, seed in zip(images, seeds) - ] + return [ResponseImage(data=img, seed=seed) for img, seed in zip(images, seeds)] def make_step_callback( @@ -326,53 +267,44 @@ def make_step_callback( data_queue: queue.Queue, task_temp_images: list, step_callback, - stream_image_progress: bool, - stream_image_progress_interval: int, ): + from easydiffusion.backend_manager import backend + n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 - def update_temp_img(x_samples, task_temp_images: list): + def update_temp_img(images, task_temp_images: list): partial_images = [] - if context.test_diffusers: - images = diffusers_latent_samples_to_images(context, x_samples) - else: - images = latent_samples_to_images(context, x_samples) + if images is None: + return [] if task_data.block_nsfw: - images = filter_images(context, images, "nsfw_checker") + images = 
filter_nsfw(images, print_log=False) for i, img in enumerate(images): + img = img.convert("RGB") + img = resize_img(img, req.width, req.height) buf = img_to_buffer(img, output_format="JPEG") - context.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"}) del images return partial_images - def on_image_step(x_samples, i, *args): + def on_image_step(images, i, *args): nonlocal last_callback_time - if context.test_diffusers: - context.partial_x_samples = (x_samples, args[0]) - else: - context.partial_x_samples = x_samples - step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 last_callback_time = time.time() progress = {"step": i, "step_time": step_time, "total_steps": n_steps} - if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: - progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images) + if images is not None: + progress["output"] = update_temp_img(images, task_temp_images) data_queue.put(json.dumps(progress)) step_callback() - if context.stop_processing: - raise UserInitiatedStop("User requested that we stop processing") - return on_image_step diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 2da62537..bc0ccabf 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -14,16 +14,19 @@ class GenerateImageRequest(BaseModel): num_outputs: int = 1 num_inference_steps: int = 50 guidance_scale: float = 7.5 + distilled_guidance_scale: float = 3.5 init_image: Any = None init_image_mask: Any = None control_image: Any = None control_alpha: Union[float, List[float]] = None + controlnet_filter: str = None prompt_strength: float = 0.8 preserve_init_image_color_profile: bool = False strict_mask_border: bool = False sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + scheduler_name: str = None hypernetwork_strength: float = 0 lora_alpha: Union[float, List[float]] = 0 tiling: str = None # None, "x", "y", "xy" @@ -213,22 +216,19 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["controlnet"] = old_req.get("use_controlnet_model") model_paths["embeddings"] = old_req.get("use_embeddings_model") - model_paths["gfpgan"] = old_req.get("use_face_correction", "") - model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None + ## ensure that the model name is in the model path + for model_name in ("gfpgan", "codeformer"): + model_paths[model_name] = old_req.get("use_face_correction", "") + model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None - model_paths["codeformer"] = old_req.get("use_face_correction", "") - model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None + for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + model_paths[model_name] = old_req.get("use_upscale", "") + model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None - model_paths["realesrgan"] = old_req.get("use_upscale", "") - model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None - - model_paths["latent_upscaler"] = old_req.get("use_upscale", "") - model_paths["latent_upscaler"] = ( - model_paths["latent_upscaler"] if 
"latent_upscaler" in model_paths["latent_upscaler"].lower() else None - ) if "control_filter_to_apply" in old_req: filter_model = old_req["control_filter_to_apply"] model_paths[filter_model] = filter_model + old_req["control_filter_to_apply"] = convert_legacy_controlnet_filter_name(old_req["control_filter_to_apply"]) if old_req.get("block_nsfw"): model_paths["nsfw_checker"] = "nsfw_checker" @@ -244,8 +244,12 @@ def convert_legacy_render_req_to_new(old_req: dict): } # move the filter params - if model_paths["realesrgan"]: - filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))} + for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if model_paths[model_name]: + filter_params[model_name] = { + "upscaler": model_paths[model_name], + "scale": int(old_req.get("upscale_amount", 4)), + } if model_paths["latent_upscaler"]: filter_params["latent_upscaler"] = { "prompt": old_req["prompt"], @@ -264,14 +268,31 @@ def convert_legacy_render_req_to_new(old_req: dict): if old_req.get("block_nsfw"): filters.append("nsfw_checker") - if model_paths["codeformer"]: - filters.append("codeformer") - elif model_paths["gfpgan"]: - filters.append("gfpgan") + for model_name in ("gfpgan", "codeformer"): + if model_paths[model_name]: + filters.append(model_name) + break - if model_paths["realesrgan"]: - filters.append("realesrgan") - elif model_paths["latent_upscaler"]: - filters.append("latent_upscaler") + for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"): + if model_paths[model_name]: + filters.append(model_name) + break return new_req + + +def convert_legacy_controlnet_filter_name(filter): + from easydiffusion.backend_manager import backend + + if filter is None: + return None + + controlnet_filter_names = backend.list_controlnet_filters() + + def apply(f): + return f"controlnet_{f}" if f in controlnet_filter_names else f + + if isinstance(filter, list): + return [apply(f) for f in filter] + + return apply(filter) diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py index a930725b..3be08e79 100644 --- a/ui/easydiffusion/utils/__init__.py +++ b/ui/easydiffusion/utils/__init__.py @@ -7,6 +7,8 @@ from .save_utils import ( save_images_to_disk, get_printable_request, ) +from .nsfw_checker import filter_nsfw + def sha256sum(filename): sha256 = hashlib.sha256() @@ -18,4 +20,3 @@ def sha256sum(filename): sha256.update(data) return sha256.hexdigest() - diff --git a/ui/easydiffusion/utils/nsfw_checker.py b/ui/easydiffusion/utils/nsfw_checker.py new file mode 100644 index 00000000..51a684df --- /dev/null +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -0,0 +1,80 @@ +# possibly move this to sdkit in the future +import os + +# mirror of https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/model_quantized.onnx +NSFW_MODEL_URL = ( + "https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/vit-base-nsfw-detector-quantized.onnx" +) +MODEL_HASH_QUICK = "220123559305b1b07b7a0894c3471e34dccd090d71cdf337dd8012f9e40d6c28" + +nsfw_check_model = None + + +def filter_nsfw(images, blur_radius: float = 75, print_log=True): + global nsfw_check_model + + from easydiffusion.model_manager import get_model_dirs + from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick + + import onnxruntime as ort + from PIL import ImageFilter + import numpy as np + + if nsfw_check_model is None: + model_dir = 
get_model_dirs("nsfw-checker")[0] + model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") + + os.makedirs(model_dir, exist_ok=True) + + if not os.path.exists(model_path) or hash_file_quick(model_path) != MODEL_HASH_QUICK: + download_file(NSFW_MODEL_URL, model_path) + + nsfw_check_model = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + # Preprocess the input image + def preprocess_image(img): + img = img.convert("RGB") + + # config based on based on https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/preprocessor_config.json + # Resize the image + img = img.resize((384, 384)) + + # Normalize the image + img = np.array(img) / 255.0 # Scale pixel values to [0, 1] + mean = np.array([0.5, 0.5, 0.5]) + std = np.array([0.5, 0.5, 0.5]) + img = (img - mean) / std + + # Transpose to match input shape (batch_size, channels, height, width) + img = np.transpose(img, (2, 0, 1)).astype(np.float32) + + # Add batch dimension + img = np.expand_dims(img, axis=0) + + return img + + # Run inference + input_name = nsfw_check_model.get_inputs()[0].name + output_name = nsfw_check_model.get_outputs()[0].name + + if print_log: + log.info("Running NSFW checker (onnx)") + + results = [] + for img in images: + is_base64 = isinstance(img, str) + + input_img = base64_str_to_img(img) if is_base64 else img + + result = nsfw_check_model.run([output_name], {input_name: preprocess_image(input_img)}) + is_nsfw = [np.argmax(arr) == 1 for arr in result][0] + + if is_nsfw: + output_img = input_img.filter(ImageFilter.GaussianBlur(blur_radius)) + output_img = img_to_base64_str(output_img) if is_base64 else output_img + else: + output_img = img + + results.append(output_img) + + return results diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index e1b4d79a..29c84f22 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -34,10 +34,12 @@ TASK_TEXT_MAPPING = { "control_alpha": "ControlNet Strength", "use_vae_model": "VAE model", "sampler_name": "Sampler", + "scheduler_name": "Scheduler", "width": "Width", "height": "Height", "num_inference_steps": "Steps", "guidance_scale": "Guidance Scale", + "distilled_guidance_scale": "Distilled Guidance", "prompt_strength": "Prompt Strength", "use_lora_model": "LoRA model", "lora_alpha": "LoRA Strength", @@ -247,7 +249,7 @@ def get_printable_request( task_data_metadata.update(save_data.dict()) app_config = app.getConfig() - using_diffusers = app_config.get("use_v3_engine", True) + using_diffusers = app_config.get("backend", "ed_diffusers") in ("ed_diffusers", "webui") # Save the metadata in the order defined in TASK_TEXT_MAPPING metadata = {} diff --git a/ui/index.html b/ui/index.html index a50c995d..90b306a7 100644 --- a/ui/index.html +++ b/ui/index.html @@ -35,7 +35,13 @@

Easy Diffusion - v3.0.9 + + + v3.0.10 + v3.5.0 + + +

@@ -73,7 +79,7 @@
- +
@@ -83,7 +89,7 @@ Click to learn more about Negative Prompts (optional) - +
@@ -174,14 +180,14 @@ - + Click to learn more about Clip Skip - +
@@ -201,40 +207,92 @@ + + + + + + + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + + + + + + + +
-
- +
@@ -248,28 +306,61 @@ - Click to learn more about samplers + Click to learn more about samplers + Tip:This sampler does not work well with Flux! + + + + Tip:This scheduler does not work well with Flux! + Tip:This scheduler needs 15 steps or more +
- + + @@ -357,14 +451,14 @@
- + - +
- + - + @@ -405,7 +499,7 @@
- +
  • @@ -418,7 +512,13 @@ @@ -825,7 +925,8 @@

    The license of this software forbids you from sharing any content that violates any laws, causing harm to a person, disseminating any personal information that would be meant for harm,
    spreading misinformation, or targeting vulnerable groups. For the full list of restrictions, please read the license.

    By using this software, you consent to the terms and conditions of the license.

    - + + diff --git a/ui/main.py b/ui/main.py index 0239c829..1c4056a3 100644 --- a/ui/main.py +++ b/ui/main.py @@ -9,6 +9,3 @@ server.init() model_manager.init() app.init_render_threads() bucket_manager.init() - -# start the browser ui -app.open_browser() diff --git a/ui/media/css/auto-save.css b/ui/media/css/auto-save.css index 94790740..021f06d5 100644 --- a/ui/media/css/auto-save.css +++ b/ui/media/css/auto-save.css @@ -79,6 +79,7 @@ } .parameters-table .fa-fire, -.parameters-table .fa-bolt { +.parameters-table .fa-bolt, +.parameters-table .fa-robot { color: #F7630C; } diff --git a/ui/media/css/main.css b/ui/media/css/main.css index 009c13d5..39b7b6e4 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -36,6 +36,15 @@ code { transform: translateY(4px); cursor: pointer; } +#engine-logo { + font-size: 8pt; + padding-left: 10pt; + color: var(--small-label-color); +} +#engine-logo a { + text-decoration: none; + /* color: var(--small-label-color); */ +} #prompt { width: 100%; height: 65pt; @@ -541,7 +550,7 @@ div.img-preview img { position: relative; background: var(--background-color4); display: flex; - padding: 12px 0 0; + padding: 6px 0 0; } .tab .icon { padding-right: 4pt; @@ -657,6 +666,15 @@ div.img-preview img { display: block; } +.gated-feature { + display: none; +} + +.warning-label { + font-size: smaller; + color: var(--status-orange); +} + .display-settings { float: right; position: relative; @@ -1459,11 +1477,6 @@ div.top-right { margin-top: 6px; } -#small_image_warning { - font-size: smaller; - color: var(--status-orange); -} - button#save-system-settings-btn { padding: 4pt 8pt; } diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 4ed28b89..16bb4aac 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -16,10 +16,12 @@ const SETTINGS_IDS_LIST = [ "clip_skip", "vae_model", "sampler_name", + "scheduler_name", "width", "height", "num_inference_steps", "guidance_scale", + "distilled_guidance_scale", "prompt_strength", "tiling", "output_format", @@ -29,6 +31,8 @@ const SETTINGS_IDS_LIST = [ "stream_image_progress", "use_face_correction", "gfpgan_model", + "codeformer_fidelity", + "codeformer_upscale_faces", "use_upscale", "upscale_amount", "latent_upscaler_steps", diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 5cb517fe..cc15cc35 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -131,6 +131,15 @@ const TASK_MAPPING = { readUI: () => parseFloat(guidanceScaleField.value), parse: (val) => parseFloat(val), }, + distilled_guidance_scale: { + name: "Distilled Guidance", + setUI: (distilled_guidance_scale) => { + distilledGuidanceScaleField.value = distilled_guidance_scale + updateDistilledGuidanceScaleSlider() + }, + readUI: () => parseFloat(distilledGuidanceScaleField.value), + parse: (val) => parseFloat(val), + }, prompt_strength: { name: "Prompt Strength", setUI: (prompt_strength) => { @@ -242,6 +251,14 @@ const TASK_MAPPING = { readUI: () => samplerField.value, parse: (val) => val, }, + scheduler_name: { + name: "Scheduler", + setUI: (scheduler_name) => { + schedulerField.value = scheduler_name + }, + readUI: () => schedulerField.value, + parse: (val) => val, + }, use_stable_diffusion_model: { name: "Stable Diffusion model", setUI: (use_stable_diffusion_model) => { @@ -590,11 +607,13 @@ const TASK_TEXT_MAPPING = { seed: "Seed", num_inference_steps: "Steps", guidance_scale: "Guidance Scale", + distilled_guidance_scale: "Distilled Guidance", prompt_strength: "Prompt Strength", use_face_correction: "Use Face 
Correction", use_upscale: "Use Upscaling", upscale_amount: "Upscale By", sampler_name: "Sampler", + scheduler_name: "Scheduler", negative_prompt: "Negative Prompt", use_stable_diffusion_model: "Stable Diffusion model", use_hypernetwork_model: "Hypernetwork model", diff --git a/ui/media/js/main.js b/ui/media/js/main.js index bff87b3b..e59a6ace 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -12,8 +12,16 @@ const taskConfigSetup = { seed: { value: ({ seed }) => seed, label: "Seed" }, dimensions: { value: ({ reqBody }) => `${reqBody?.width}x${reqBody?.height}`, label: "Dimensions" }, sampler_name: "Sampler", + scheduler_name: { + label: "Scheduler", + visible: ({ reqBody }) => reqBody?.scheduler_name, + }, num_inference_steps: "Inference Steps", guidance_scale: "Guidance Scale", + distilled_guidance_scale: { + label: "Distilled Guidance Scale", + visible: ({ reqBody }) => reqBody?.distilled_guidance_scale, + }, use_stable_diffusion_model: "Model", clip_skip: { label: "Clip Skip", @@ -76,6 +84,8 @@ let numOutputsParallelField = document.querySelector("#num_outputs_parallel") let numInferenceStepsField = document.querySelector("#num_inference_steps") let guidanceScaleSlider = document.querySelector("#guidance_scale_slider") let guidanceScaleField = document.querySelector("#guidance_scale") +let distilledGuidanceScaleSlider = document.querySelector("#distilled_guidance_scale_slider") +let distilledGuidanceScaleField = document.querySelector("#distilled_guidance_scale") let outputQualitySlider = document.querySelector("#output_quality_slider") let outputQualityField = document.querySelector("#output_quality") let outputQualityRow = document.querySelector("#output_quality_row") @@ -113,6 +123,8 @@ let promptStrengthSlider = document.querySelector("#prompt_strength_slider") let promptStrengthField = document.querySelector("#prompt_strength") let samplerField = document.querySelector("#sampler_name") let samplerSelectionContainer = document.querySelector("#samplerSelection") +let schedulerField = document.querySelector("#scheduler_name") +let schedulerSelectionContainer = document.querySelector("#schedulerSelection") let useFaceCorrectionField = document.querySelector("#use_face_correction") let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model"), ["gfpgan", "codeformer"], "", false) let useUpscalingField = document.querySelector("#use_upscale") @@ -981,7 +993,20 @@ function onRedoFilter(req, img, e, tools) { function onUpscaleClick(req, img, e, tools) { let path = upscaleModelField.value let scale = parseInt(upscaleAmountField.value) - let filterName = path.toLowerCase().includes("realesrgan") ? 
"realesrgan" : "latent_upscaler" + + let filterName = null + const FILTERS = ["realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"] + for (let idx in FILTERS) { + let f = FILTERS[idx] + if (path.toLowerCase().includes(f)) { + filterName = f + break + } + } + + if (!filterName) { + return + } let statusText = "Upscaling by " + scale + "x using " + filterName applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools) } @@ -1038,6 +1063,9 @@ function makeImage() { if (guidanceScaleField.value == "") { guidanceScaleField.value = guidanceScaleSlider.value / 10 } + if (distilledGuidanceScaleField.value == "") { + distilledGuidanceScaleField.value = distilledGuidanceScaleSlider.value / 10 + } if (hypernetworkStrengthField.value == "") { hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100 } @@ -1406,6 +1434,12 @@ function getCurrentUserRequest() { newTask.reqBody.control_filter_to_apply = controlImageFilterField.value } } + if (stableDiffusionModelField.value.toLowerCase().includes("flux")) { + newTask.reqBody.distilled_guidance_scale = parseFloat(distilledGuidanceScaleField.value) + } + if (schedulerSelectionContainer.style.display !== "none") { + newTask.reqBody.scheduler_name = schedulerField.value + } return newTask } @@ -1845,36 +1879,137 @@ controlImagePreview.addEventListener("load", onControlnetModelChange) controlImagePreview.addEventListener("unload", onControlnetModelChange) onControlnetModelChange() -function onControlImageFilterChange() { - let filterId = controlImageFilterField.value - if (filterId.includes("openpose")) { - controlnetModelField.value = "control_v11p_sd15_openpose" - } else if (filterId === "canny") { - controlnetModelField.value = "control_v11p_sd15_canny" - } else if (filterId === "mlsd") { - controlnetModelField.value = "control_v11p_sd15_mlsd" - } else if (filterId === "mlsd") { - controlnetModelField.value = "control_v11p_sd15_mlsd" - } else if (filterId.includes("scribble")) { - controlnetModelField.value = "control_v11p_sd15_scribble" - } else if (filterId.includes("softedge")) { - controlnetModelField.value = "control_v11p_sd15_softedge" - } else if (filterId === "normal_bae") { - controlnetModelField.value = "control_v11p_sd15_normalbae" - } else if (filterId.includes("depth")) { - controlnetModelField.value = "control_v11f1p_sd15_depth" - } else if (filterId === "lineart_anime") { - controlnetModelField.value = "control_v11p_sd15s2_lineart_anime" - } else if (filterId.includes("lineart")) { - controlnetModelField.value = "control_v11p_sd15_lineart" - } else if (filterId === "shuffle") { - controlnetModelField.value = "control_v11e_sd15_shuffle" - } else if (filterId === "segment") { - controlnetModelField.value = "control_v11p_sd15_seg" +document.addEventListener("refreshModels", function() { + onFixFaceModelChange() + onControlnetModelChange() +}) + +// tip for Flux +let sdModelField = document.querySelector("#stable_diffusion_model") +function checkGuidanceValue() { + let guidance = parseFloat(guidanceScaleField.value) + let guidanceWarning = document.querySelector("#guidanceWarning") + let guidanceWarningText = document.querySelector("#guidanceWarningText") + if (sdModelField.value.toLowerCase().includes("flux")) { + if (guidance > 1.5) { + guidanceWarningText.innerText = "Flux recommends a 'Guidance Scale' of 1" + guidanceWarning.classList.remove("displayNone") + } else { + guidanceWarning.classList.add("displayNone") + } + } else { + if (guidance < 2) { + 
+            guidanceWarningText.innerText = "A higher 'Guidance Scale' is recommended!"
+            guidanceWarning.classList.remove("displayNone")
+        } else {
+            guidanceWarning.classList.add("displayNone")
+        }
     }
 }
-controlImageFilterField.addEventListener("change", onControlImageFilterChange)
-onControlImageFilterChange()
+sdModelField.addEventListener("change", checkGuidanceValue)
+guidanceScaleField.addEventListener("change", checkGuidanceValue)
+guidanceScaleSlider.addEventListener("change", checkGuidanceValue)
+
+function checkGuidanceScaleVisibility() {
+    let distilledGuidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container")
+    if (sdModelField.value.toLowerCase().includes("flux")) {
+        distilledGuidanceScaleContainer.classList.remove("displayNone")
+    } else {
+        distilledGuidanceScaleContainer.classList.add("displayNone")
+    }
+}
+sdModelField.addEventListener("change", checkGuidanceScaleVisibility)
+
+function checkFluxSampler() {
+    // Euler Ancestral tends to give poor results with Flux, so warn when it is selected
+    let samplerWarning = document.querySelector("#fluxSamplerWarning")
+    if (sdModelField.value.toLowerCase().includes("flux")) {
+        if (samplerField.value == "euler_a") {
+            samplerWarning.classList.remove("displayNone")
+        } else {
+            samplerWarning.classList.add("displayNone")
+        }
+    } else {
+        samplerWarning.classList.add("displayNone")
+    }
+}
+
+function checkFluxScheduler() {
+    const badSchedulers = ["automatic", "uniform", "turbo", "align_your_steps", "align_your_steps_GITS", "align_your_steps_11", "align_your_steps_32"]
+
+    let schedulerWarning = document.querySelector("#fluxSchedulerWarning")
+    if (sdModelField.value.toLowerCase().includes("flux")) {
+        if (badSchedulers.includes(schedulerField.value)) {
+            schedulerWarning.classList.remove("displayNone")
+        } else {
+            schedulerWarning.classList.add("displayNone")
+        }
+    } else {
+        schedulerWarning.classList.add("displayNone")
+    }
+}
+
+function checkFluxSchedulerSteps() {
+    // these schedulers typically need 15 or more steps to give good results with Flux
+    const problematicSchedulers = ["karras", "exponential", "polyexponential"]
+
+    let schedulerWarning = document.querySelector("#fluxSchedulerStepsWarning")
+    if (sdModelField.value.toLowerCase().includes("flux") && parseInt(numInferenceStepsField.value) < 15) {
+        if (problematicSchedulers.includes(schedulerField.value)) {
+            schedulerWarning.classList.remove("displayNone")
+        } else {
+            schedulerWarning.classList.add("displayNone")
+        }
+    } else {
+        schedulerWarning.classList.add("displayNone")
+    }
+}
+sdModelField.addEventListener("change", checkFluxSampler)
+samplerField.addEventListener("change", checkFluxSampler)
+
+sdModelField.addEventListener("change", checkFluxScheduler)
+schedulerField.addEventListener("change", checkFluxScheduler)
+
+sdModelField.addEventListener("change", checkFluxSchedulerSteps)
+schedulerField.addEventListener("change", checkFluxSchedulerSteps)
+numInferenceStepsField.addEventListener("change", checkFluxSchedulerSteps)
+
+document.addEventListener("refreshModels", function() {
+    checkGuidanceValue()
+    checkGuidanceScaleVisibility()
+    checkFluxSampler()
+    checkFluxScheduler()
+    checkFluxSchedulerSteps()
+})
+
+// function onControlImageFilterChange() {
+//     let filterId = controlImageFilterField.value
+//     if (filterId.includes("openpose")) {
+//         controlnetModelField.value = "control_v11p_sd15_openpose"
+//     } else if (filterId === "canny") {
+//         controlnetModelField.value = "control_v11p_sd15_canny"
+//     } else if (filterId === "mlsd") {
+//         controlnetModelField.value = "control_v11p_sd15_mlsd"
+//     } else if (filterId.includes("scribble")) {
+//         controlnetModelField.value = "control_v11p_sd15_scribble"
+//     } else if (filterId.includes("softedge")) {
+//         controlnetModelField.value = "control_v11p_sd15_softedge"
+//     } else if (filterId === "normal_bae") {
+//         controlnetModelField.value = "control_v11p_sd15_normalbae"
+//     } else if (filterId.includes("depth")) {
+//         controlnetModelField.value = "control_v11f1p_sd15_depth"
+//     } else if (filterId === "lineart_anime") {
+//         controlnetModelField.value = "control_v11p_sd15s2_lineart_anime"
+//     } else if (filterId.includes("lineart")) {
+//         controlnetModelField.value = "control_v11p_sd15_lineart"
+//     } else if (filterId === "shuffle") {
+//         controlnetModelField.value = "control_v11e_sd15_shuffle"
+//     } else if (filterId === "segment") {
+//         controlnetModelField.value = "control_v11p_sd15_seg"
+//     }
+// }
+// controlImageFilterField.addEventListener("change", onControlImageFilterChange)
+// onControlImageFilterChange()
 upscaleModelField.disabled = !useUpscalingField.checked
 upscaleAmountField.disabled = !useUpscalingField.checked
@@ -1973,6 +2108,27 @@ guidanceScaleSlider.addEventListener("input", updateGuidanceScale)
 guidanceScaleField.addEventListener("input", updateGuidanceScaleSlider)
 updateGuidanceScale()
+/********************* Distilled Guidance **************************/
+function updateDistilledGuidanceScale() {
+    distilledGuidanceScaleField.value = distilledGuidanceScaleSlider.value / 10
+    distilledGuidanceScaleField.dispatchEvent(new Event("change"))
+}
+
+function updateDistilledGuidanceScaleSlider() {
+    if (distilledGuidanceScaleField.value < 0) {
+        distilledGuidanceScaleField.value = 0
+    } else if (distilledGuidanceScaleField.value > 50) {
+        distilledGuidanceScaleField.value = 50
+    }
+
+    distilledGuidanceScaleSlider.value = distilledGuidanceScaleField.value * 10
+    distilledGuidanceScaleSlider.dispatchEvent(new Event("change"))
+}
+
+distilledGuidanceScaleSlider.addEventListener("input", updateDistilledGuidanceScale)
+distilledGuidanceScaleField.addEventListener("input", updateDistilledGuidanceScaleSlider)
+updateDistilledGuidanceScale()
+
 /********************* Prompt Strength *******************/
 function updatePromptStrength() {
     promptStrengthField.value = promptStrengthSlider.value / 100
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index 97b7a96d..65bc82d5 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -102,7 +102,7 @@ var PARAMETERS = [
         type: ParameterType.custom,
         icon: "fa-folder-tree",
         label: "Models Folder",
-        note: "Path to the 'models' folder. Please save and refresh the page after changing this.",
+        note: "Path to the 'models' folder. Please save and restart Easy Diffusion after changing this.",
         saveInAppConfig: true,
         render: (parameter) => {
             return ``
         }
@@ -161,6 +161,7 @@ var PARAMETERS = [
             "Low: slowest, recommended for GPUs with 3 to 4 GB memory",
         icon: "fa-forward",
         default: "balanced",
+        saveInAppConfig: true,
         options: [
             { value: "balanced", label: "Balanced" },
             { value: "high", label: "High" },
@@ -249,14 +250,19 @@ var PARAMETERS = [
         default: false,
     },
     {
-        id: "use_v3_engine",
-        type: ParameterType.checkbox,
-        label: "Use the new v3 engine (diffusers)",
+        id: "backend",
+        type: ParameterType.select,
+        label: "Engine to use",
         note:
-            "Use our new v3 engine, with additional features like LoRA, ControlNet, SDXL, Embeddings, Tiling and lots more! Please press Save, then restart the program after changing this.",
-        icon: "fa-bolt",
-        default: true,
+            "Use our new v3.5 engine (Forge), with additional features like Flux, SD3, LyCORIS and lots more! Please press Save, then restart the program after changing this.",
+        icon: "fa-robot",
         saveInAppConfig: true,
+        default: "ed_diffusers",
+        options: [
+            { value: "webui", label: "v3.5 (latest)" },
+            { value: "ed_diffusers", label: "v3.0" },
+            { value: "ed_classic", label: "v2.0" },
+        ],
     },
     {
         id: "cloudflare",
@@ -432,6 +438,7 @@ let useBetaChannelField = document.querySelector("#use_beta_channel")
 let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
 let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions")
 let testDiffusers = document.querySelector("#use_v3_engine")
+let backendEngine = document.querySelector("#backend")
 let profileNameField = document.querySelector("#profileName")
 let modelsDirField = document.querySelector("#models_dir")
 
@@ -454,6 +461,23 @@ async function changeAppConfig(configDelta) {
     }
 }
 
+function getDefaultDisplay(element) {
+    const tag = element.tagName.toLowerCase()
+    const defaultDisplays = {
+        div: "block",
+        span: "inline",
+        p: "block",
+        tr: "table-row",
+        table: "table",
+        li: "list-item",
+        ul: "block",
+        ol: "block",
+        button: "inline-block",
+        // add more tags here if needed
+    }
+    return defaultDisplays[tag] || "block" // fall back to "block" if the tag isn't listed
+}
+
 async function getAppConfig() {
     try {
         let res = await fetch("/get/app_config")
@@ -478,14 +502,16 @@ async function getAppConfig() {
         modelsDirField.value = config.models_dir
 
         let testDiffusersEnabled = true
-        if (config.use_v3_engine === false) {
+        if (config.backend === "ed_classic") {
            testDiffusersEnabled = false
         }
         testDiffusers.checked = testDiffusersEnabled
+        backendEngine.value = config.backend
         document.querySelector("#test_diffusers").checked = testDiffusers.checked // don't break plugins
+        document.querySelector("#use_v3_engine").checked = testDiffusers.checked // don't break plugins
 
         if (config.config_on_startup) {
-            if (config.config_on_startup?.use_v3_engine) {
+            if (config.config_on_startup?.backend !== "ed_classic") {
                 document.body.classList.add("diffusers-enabled-on-startup")
                 document.body.classList.remove("diffusers-disabled-on-startup")
             } else {
@@ -494,37 +520,27 @@
             }
         }
 
-        if (!testDiffusersEnabled) {
-            document.querySelector("#lora_model_container").style.display = "none"
-            document.querySelector("#tiling_container").style.display = "none"
-            document.querySelector("#controlnet_model_container").style.display = "none"
-            document.querySelector("#hypernetwork_model_container").style.display = ""
-            document.querySelector("#hypernetwork_strength_container").style.display = ""
-            document.querySelector("#negative-embeddings-button").style.display = "none"
-
-            document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => {
-                option.style.display = "none"
-            })
+        if (config.backend === "ed_classic") {
             IMAGE_STEP_SIZE = 64
-            customWidthField.step = IMAGE_STEP_SIZE
-            customHeightField.step = IMAGE_STEP_SIZE
         } else {
-            document.querySelector("#lora_model_container").style.display = ""
-            document.querySelector("#tiling_container").style.display = ""
-            document.querySelector("#controlnet_model_container").style.display = ""
-            document.querySelector("#hypernetwork_model_container").style.display = "none"
-            document.querySelector("#hypernetwork_strength_container").style.display = "none"
-            document.querySelectorAll("#sampler_name option.k_diffusion-only").forEach((option) => {
-                option.style.display = "none"
-            })
-            document.querySelector("#clip_skip_config").classList.remove("displayNone")
-            document.querySelector("#embeddings-button").classList.remove("displayNone")
             IMAGE_STEP_SIZE = 8
-            customWidthField.step = IMAGE_STEP_SIZE
-            customHeightField.step = IMAGE_STEP_SIZE
         }
+        customWidthField.step = IMAGE_STEP_SIZE
+        customHeightField.step = IMAGE_STEP_SIZE
+
+        const currentBackendKey = "backend_" + config.backend
+
+        document.querySelectorAll(".gated-feature").forEach((element) => {
+            const featureKeys = element.getAttribute("data-feature-keys").split(" ")
+
+            if (featureKeys.includes(currentBackendKey)) {
+                element.style.display = getDefaultDisplay(element)
+            } else {
+                element.style.display = "none"
+            }
+        })
+
         if (config.force_save_metadata) {
             metadataOutputFormatField.value = config.force_save_metadata
         }
@@ -749,6 +765,11 @@ async function getSystemInfo() {
             metadataOutputFormatField.disabled = !saveToDiskField.checked
         }
         setDiskPath(res["default_output_dir"], force)
+
+        // backend info
+        if (res["backend_url"]) {
+            document.querySelector("#backend-url").setAttribute("href", res["backend_url"])
+        }
     } catch (e) {
         console.log("error fetching devices", e)
     }
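
A note on the rewritten `onUpscaleClick()` above: it now dispatches on the first known filter name that appears anywhere in the upscaler model path, and silently skips unknown paths instead of falling back to `latent_upscaler`. The same lookup can be expressed with `Array.find()`; the model path below is a made-up example:

```js
const FILTERS = ["realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"]

// find() returns undefined when nothing matches, which plays the same role
// as the early return in onUpscaleClick() when filterName stays null.
const path = "models/upscalers/SwinIR_4x.pth" // hypothetical model path
const filterName = FILTERS.find((f) => path.toLowerCase().includes(f))
console.log(filterName) // "swinir"
```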
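The distilled-guidance slider follows the same convention as the existing guidance-scale slider: the slider position is ten times the field value, and the field is clamped to the range 0 to 50, so the slider effectively spans 0 to 500. A minimal sketch of that mapping, with hypothetical helper names:

```js
// Hypothetical helpers mirroring updateDistilledGuidanceScale() and
// updateDistilledGuidanceScaleSlider() in the patch above.
function sliderToField(sliderValue) {
    return sliderValue / 10 // slider 35 -> field 3.5
}

function fieldToSlider(fieldValue) {
    const clamped = Math.min(Math.max(fieldValue, 0), 50) // field is clamped to [0, 50]
    return clamped * 10
}

console.log(sliderToField(35)) // 3.5
console.log(fieldToSlider(3.5)) // 35
console.log(fieldToSlider(99)) // 500, after clamping the field to 50
```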
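For reviewers tracing the request flow: `getCurrentUserRequest()` only attaches `distilled_guidance_scale` when the selected model name contains "flux", and only attaches `scheduler_name` while the scheduler dropdown is visible. A rough sketch of the resulting request body; the model name and the specific values are illustrative, not defaults taken from this patch:

```js
// Illustrative only: the Flux-specific fields added to reqBody.
const reqBody = {
    use_stable_diffusion_model: "flux1-dev", // hypothetical model name containing "flux"
    guidance_scale: 1.0, // checkGuidanceValue() warns when this exceeds 1.5 for Flux
    distilled_guidance_scale: 3.5, // sent because the model name contains "flux"
    num_inference_steps: 20, // checkFluxSchedulerSteps() warns below 15 steps for some schedulers
    scheduler_name: "karras", // sent only while #schedulerSelection is visible
}
```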
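The `gated-feature` mechanism in `getAppConfig()` replaces the old hard-coded show/hide calls: an element opts in with the `gated-feature` class and lists the backends it supports in a space-separated `data-feature-keys` attribute, which is matched against `"backend_" + config.backend`. A minimal sketch of an element that opts in; the row itself is hypothetical, and only the class name, the attribute name, and the `backend_` prefix come from this patch:

```js
// Hypothetical: a table row that should only appear on the webui and
// ed_diffusers engines, never on ed_classic.
const row = document.createElement("tr")
row.classList.add("gated-feature")
row.setAttribute("data-feature-keys", "backend_webui backend_ed_diffusers")
document.querySelector("#some_settings_table")?.appendChild(row) // hypothetical parent

// When getAppConfig() later runs with config.backend === "webui", the key
// "backend_webui" matches, so the row gets its natural display value
// (getDefaultDisplay() returns "table-row" for a <tr>); otherwise it is hidden.
```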