Merge pull request #1847 from easydiffusion/forge

ED 3.5 - Forge as a new backend
cmdr2 2024-10-12 12:52:16 +05:30 committed by GitHub
commit d8c3d7cf92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 2469 additions and 458 deletions


@@ -1,5 +1,21 @@
# What's new?
## v3.5 (preview)
### Major Changes
- **Flux** - full support for the Flux model, including quantized bnb and nf4 models.
- **LyCORIS** - including `LoCon`, `Hada`, `IA3` and `Lokr`.
- **11 new samplers** - `DDIM CFG++`, `DPM Fast`, `DPM++ 2m SDE Heun`, `DPM++ 3M SDE`, `Restart`, `Heun PP2`, `IPNDM`, `IPNDM_V`, `LCM`, `[Forge] Flux Realistic`, `[Forge] Flux Realistic (Slow)`.
- **15 new schedulers** - `Uniform`, `Karras`, `Exponential`, `Polyexponential`, `SGM Uniform`, `KL Optimal`, `Align Your Steps`, `Simple`, `Normal`, `DDIM`, `Beta`, `Turbo`, `Align Your Steps GITS`, `Align Your Steps 11`, `Align Your Steps 32`.
- **42 new ControlNet filters, and support for lots of new ControlNet models** (including QR ControlNets).
- **5 upscalers** - `SwinIR`, `ScuNET`, `Nearest`, `Lanczos`, `ESRGAN`.
- **Faster than v3.0**
- **Major rewrite of the code** - We've switched to `Forge WebUI` under the hood, which brings a lot of new features, faster image generation, and support for all the extensions in the Forge/Automatic1111 community. This allows Easy Diffusion to stay up-to-date with the latest features, and focus on making the UI and installation experience even easier.
v3.5 is currently an optional upgrade, and you can switch between the v3.0 (diffusers) engine and the v3.5 (webui) engine using the `Settings` tab in the UI.
### Detailed changelog
* 3.5.0 - 11 Oct 2024 - **Preview release** of the new v3.5 engine, powered by Forge WebUI (a fork of Automatic1111). This enables Flux, SD3, LyCORIS and lots of new features, while using the same familiar Easy Diffusion interface.
## v3.0
### Major Changes
- **ControlNet** - Full support for ControlNet, with native integration of the common ControlNet models. Just select a control image, then choose the ControlNet filter/model and run. No additional configuration or download necessary. Supports custom ControlNets as well.
@@ -17,6 +33,7 @@
- **Major rewrite of the code** - We've switched to using diffusers under-the-hood, which allows us to release new features faster, and focus on making the UI and installer even easier to use.
### Detailed changelog
* 3.0.10 - 11 Oct 2024 - **Major Update** - An option to upgrade to v3.5, which enables Flux, Stable Diffusion 3, LyCORIS models and lots more.
* 3.0.9 - 28 May 2024 - Slider for controlling the strength of controlnets.
* 3.0.8 - 27 May 2024 - SDXL ControlNets for Img2Img and Inpainting.
* 3.0.7 - 11 Dec 2023 - Setting to enable/disable VAE tiling (in the Image Settings panel). Sometimes VAE tiling reduces the quality of the image, so this setting will help control that.


@@ -34,6 +34,7 @@ modules_to_check = {
"sqlalchemy": "2.0.19",
"python-multipart": "0.0.6",
# "xformers": "0.0.16",
"onnxruntime": "1.19.2",
}
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
@ -297,7 +298,7 @@ Thanks!"""
def get_config(): def get_config():
config_directory = os.path.dirname(__file__) # this will be "scripts" config_directory = os.path.dirname(__file__) # this will be "scripts"
config_yaml = os.path.join(config_directory, "..", "config.yaml") config_yaml = os.path.abspath(os.path.join(config_directory, "..", "config.yaml"))
config_json = os.path.join(config_directory, "config.json") config_json = os.path.join(config_directory, "config.json")
config = None config = None


@@ -6,6 +6,7 @@ import shutil
# The config file is in the same directory as this script
config_directory = os.path.dirname(__file__)
config_yaml = os.path.join(config_directory, "..", "config.yaml")
config_yaml = os.path.abspath(config_yaml)
config_json = os.path.join(config_directory, "config.json")
parser = argparse.ArgumentParser(description='Get values from config file')


@@ -71,6 +71,7 @@ if "%update_branch%"=="" (
@copy sd-ui-files\scripts\check_modules.py scripts\ /Y
@copy sd-ui-files\scripts\get_config.py scripts\ /Y
@copy sd-ui-files\scripts\config.yaml.sample scripts\ /Y
@copy sd-ui-files\scripts\webui_console.py scripts\ /Y
@copy "sd-ui-files\scripts\Start Stable Diffusion UI.cmd" . /Y
@copy "sd-ui-files\scripts\Developer Console.cmd" . /Y


@@ -54,6 +54,7 @@ cp sd-ui-files/scripts/bootstrap.sh scripts/
cp sd-ui-files/scripts/check_modules.py scripts/
cp sd-ui-files/scripts/get_config.py scripts/
cp sd-ui-files/scripts/config.yaml.sample scripts/
cp sd-ui-files/scripts/webui_console.py scripts/
cp sd-ui-files/scripts/start.sh .
cp sd-ui-files/scripts/developer_console.sh .
cp sd-ui-files/scripts/functions.sh scripts/


@@ -7,6 +7,7 @@
@copy sd-ui-files\scripts\check_modules.py scripts\ /Y
@copy sd-ui-files\scripts\get_config.py scripts\ /Y
@copy sd-ui-files\scripts\config.yaml.sample scripts\ /Y
@copy sd-ui-files\scripts\webui_console.py scripts\ /Y
if exist "%cd%\profile" (
set HF_HOME=%cd%\profile\.cache\huggingface


@@ -6,16 +6,20 @@ cp sd-ui-files/scripts/bootstrap.sh scripts/
cp sd-ui-files/scripts/check_modules.py scripts/
cp sd-ui-files/scripts/get_config.py scripts/
cp sd-ui-files/scripts/config.yaml.sample scripts/
cp sd-ui-files/scripts/webui_console.py scripts/
source ./scripts/functions.sh
# activate the installer env
-CONDA_BASEPATH=$(conda info --base)
+export CONDA_BASEPATH=$(conda info --base)
source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # avoids the 'shell not initialized' error
conda activate || fail "Failed to activate conda"
# hack to fix conda 4.14 on older installations
cp $CONDA_BASEPATH/condabin/conda $CONDA_BASEPATH/bin/conda
# remove the old version of the dev console script, if it's still present
if [ -e "open_dev_console.sh" ]; then
rm "open_dev_console.sh"

scripts/webui_console.py (new file, 101 lines)

@@ -0,0 +1,101 @@
import os
import platform
import subprocess
def configure_env(dir):
env_entries = {
"PATH": [
f"{dir}",
f"{dir}/bin",
f"{dir}/Library/bin",
f"{dir}/Scripts",
f"{dir}/usr/bin",
],
"PYTHONPATH": [
f"{dir}",
f"{dir}/lib/site-packages",
f"{dir}/lib/python3.10/site-packages",
],
"PYTHONHOME": [],
"PY_LIBS": [
f"{dir}/Scripts/Lib",
f"{dir}/Scripts/Lib/site-packages",
f"{dir}/lib",
f"{dir}/lib/python3.10/site-packages",
],
"PY_PIP": [f"{dir}/Scripts", f"{dir}/bin"],
}
if platform.system() == "Windows":
env_entries["PATH"].append("C:/Windows/System32")
env_entries["PATH"].append("C:/Windows/System32/wbem")
env_entries["PYTHONNOUSERSITE"] = ["1"]
env_entries["PYTHON"] = [f"{dir}/python"]
env_entries["GIT"] = [f"{dir}/Library/bin/git"]
else:
env_entries["PATH"].append("/bin")
env_entries["PATH"].append("/usr/bin")
env_entries["PATH"].append("/usr/sbin")
env_entries["PYTHONNOUSERSITE"] = ["y"]
env_entries["PYTHON"] = [f"{dir}/bin/python"]
env_entries["GIT"] = [f"{dir}/bin/git"]
env = {}
for key, paths in env_entries.items():
paths = [p.replace("/", os.path.sep) for p in paths]
paths = os.pathsep.join(paths)
os.environ[key] = paths
return env
def print_env_info():
which_cmd = "where" if platform.system() == "Windows" else "which"
python = "python"
def locate_python():
nonlocal python
python = subprocess.getoutput(f"{which_cmd} python")
python = python.split("\n")
python = python[0].strip()
print("python: ", python)
locate_python()
def run(cmd):
with subprocess.Popen(cmd) as p:
p.wait()
run([which_cmd, "git"])
run(["git", "--version"])
run([which_cmd, "python"])
run([python, "--version"])
print(f"PATH={os.environ['PATH']}")
if platform.system() == "Windows":
print(f"COMSPEC={os.environ['COMSPEC']}")
print("")
run("wmic path win32_VideoController get name,AdapterRAM,DriverDate,DriverVersion".split(" "))
print(f"PYTHONPATH={os.environ['PYTHONPATH']}")
print("")
def open_dev_shell():
if platform.system() == "Windows":
subprocess.Popen("cmd").communicate()
else:
subprocess.Popen("bash").communicate()
if __name__ == "__main__":
env_dir = os.path.abspath(os.path.join("webui", "system"))
configure_env(env_dir)
print_env_info()
open_dev_shell()
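A usage sketch (an assumption based on the paths above, since `env_dir` is resolved relative to the working directory): running `python scripts/webui_console.py` from the Easy Diffusion root folder puts the `webui/system` environment on `PATH`/`PYTHONPATH`, prints the resolved `git` and `python` locations and versions, and drops into an interactive `cmd` (Windows) or `bash` shell.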


@@ -11,7 +11,7 @@ from ruamel.yaml import YAML
import urllib
import warnings
-from easydiffusion import task_manager
+from easydiffusion import task_manager, backend_manager
from easydiffusion.utils import log
from rich.logging import RichHandler
from rich.console import Console
@@ -36,10 +36,10 @@ ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, ".."))
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
-CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
+CONFIG_DIR = os.path.abspath(os.path.join(ROOT_DIR, "scripts"))
-BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket"))
+BUCKET_DIR = os.path.abspath(os.path.join(ROOT_DIR, "bucket"))
-USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins"))
+USER_PLUGINS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "plugins"))
CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins"))
USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui")
@@ -60,7 +60,7 @@ APP_CONFIG_DEFAULTS = {
"ui": {
"open_browser_on_start": True,
},
-"use_v3_engine": True,
+"backend": "ed_diffusers",
}
IMAGE_EXTENSIONS = [
@@ -77,7 +77,7 @@ IMAGE_EXTENSIONS = [
".avif",
".svg",
]
-CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
+CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "modifiers"))
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [
".portrait",
"_portrait",
@@ -91,7 +91,7 @@ CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [
"-landscape",
]
-MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
+MODELS_DIR = os.path.abspath(os.path.join(ROOT_DIR, "models"))
def init():
@@ -105,9 +105,11 @@ def init():
config = getConfig()
config_models_dir = config.get("models_dir", None)
-if (config_models_dir is not None and config_models_dir != ""):
+if config_models_dir is not None and config_models_dir != "":
MODELS_DIR = config_models_dir
backend_manager.start_backend()
def init_render_threads():
load_server_plugins()
@@ -117,6 +119,7 @@ def init_render_threads():
def getConfig(default_val=APP_CONFIG_DEFAULTS):
config_yaml_path = os.path.join(CONFIG_DIR, "..", "config.yaml")
config_yaml_path = os.path.abspath(config_yaml_path)
# migrate the old config yaml location
config_legacy_yaml = os.path.join(CONFIG_DIR, "config.yaml")
@@ -124,9 +127,9 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
shutil.move(config_legacy_yaml, config_yaml_path)
def set_config_on_startup(config: dict):
-if getConfig.__use_v3_engine_on_startup is None:
+if getConfig.__use_backend_on_startup is None:
-getConfig.__use_v3_engine_on_startup = config.get("use_v3_engine", True)
+getConfig.__use_backend_on_startup = config.get("backend", "ed_diffusers")
-config["config_on_startup"] = {"use_v3_engine": getConfig.__use_v3_engine_on_startup}
+config["config_on_startup"] = {"backend": getConfig.__use_backend_on_startup}
if os.path.isfile(config_yaml_path):
try:
@@ -144,6 +147,15 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
else:
config["net"]["listen_to_network"] = True
if "backend" not in config:
if "use_v3_engine" in config:
config["backend"] = "ed_diffusers" if config["use_v3_engine"] else "ed_classic"
else:
config["backend"] = "ed_diffusers"
# this default will need to be smarter when WebUI becomes the main backend, but needs to maintain backwards
# compatibility with existing ED 3.0 installations that haven't opted into the WebUI backend, and haven't
# set a "use_v3_engine" flag in their config
set_config_on_startup(config)
return config
@@ -174,7 +186,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
return default_val
-getConfig.__use_v3_engine_on_startup = None
+getConfig.__use_backend_on_startup = None
def setConfig(config):
@@ -307,11 +319,14 @@ def getIPConfig():
def open_browser():
from easydiffusion.backend_manager import backend
config = getConfig()
ui = config.get("ui", {})
net = config.get("net", {})
port = net.get("listen_port", 9000)
if backend.is_installed():
if ui.get("open_browser_on_start", True):
import webbrowser
@@ -329,6 +344,18 @@ def open_browser():
style="bold yellow on blue",
)
)
else:
backend_name = config["backend"]
Console().print(
Panel(
"\n"
+ f"[white]Backend: {backend_name} is still installing..\n\n"
+ "A new browser tab will open automatically after it finishes.\n"
+ f"If it does not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n",
title=f"Backend engine is installing",
style="bold yellow on blue",
)
)
def fail_and_die(fail_type: str, data: str):


@@ -0,0 +1,105 @@
import os
import ast
import sys
import importlib.util
import traceback
from easydiffusion.utils import log
backend = None
curr_backend_name = None
def is_valid_backend(file_path):
with open(file_path, "r", encoding="utf-8") as file:
node = ast.parse(file.read())
# Check for presence of a dictionary named 'ed_info'
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name) and target.id == "ed_info":
return True
return False
def find_valid_backends(root_dir) -> dict:
backends_path = os.path.join(root_dir, "backends")
valid_backends = {}
if not os.path.exists(backends_path):
return valid_backends
for item in os.listdir(backends_path):
item_path = os.path.join(backends_path, item)
if os.path.isdir(item_path):
init_file = os.path.join(item_path, "__init__.py")
if os.path.exists(init_file) and is_valid_backend(init_file):
valid_backends[item] = item_path
elif item.endswith(".py"):
if is_valid_backend(item_path):
backend_name = os.path.splitext(item)[0] # strip the .py extension
valid_backends[backend_name] = item_path
return valid_backends
def load_backend_module(backend_name, backend_dict):
if backend_name not in backend_dict:
raise ValueError(f"Backend '{backend_name}' not found.")
module_path = backend_dict[backend_name]
mod_dir = os.path.dirname(module_path)
sys.path.insert(0, mod_dir)
# If it's a package (directory), add its parent directory to sys.path
if os.path.isdir(module_path):
module_path = os.path.join(module_path, "__init__.py")
spec = importlib.util.spec_from_file_location(backend_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if mod_dir in sys.path:
sys.path.remove(mod_dir)
log.info(f"Loaded backend: {module}")
return module
def start_backend():
global backend, curr_backend_name
from easydiffusion.app import getConfig, ROOT_DIR
curr_dir = os.path.dirname(__file__)
backends = find_valid_backends(curr_dir)
plugin_backends = find_valid_backends(ROOT_DIR)
backends.update(plugin_backends)
config = getConfig()
backend_name = config["backend"]
if backend_name not in backends:
raise RuntimeError(
f"Couldn't find the backend configured in config.yaml: {backend_name}. Please check the name!"
)
if backend is not None and backend_name != curr_backend_name:
try:
backend.stop_backend()
except:
log.exception(traceback.format_exc())
log.info(f"Loading backend: {backend_name}")
backend = load_backend_module(backend_name, backends)
try:
backend.start_backend()
except:
log.exception(traceback.format_exc())
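To make the discovery contract concrete, here is a minimal sketch of a pluggable backend (hypothetical name and contents, not part of this PR). Any `.py` file, or package with an `__init__.py`, placed in a `backends` folder (the built-in one next to `backend_manager.py`, or one under the installation's root directory for plugins) is accepted by `find_valid_backends` as long as it assigns a module-level `ed_info` dict. The functions mirror the surface the bundled backends export; `start_backend` and `stop_backend` are the hooks `backend_manager` itself invokes:

```python
# backends/example/__init__.py (hypothetical) - a minimal Easy Diffusion backend.

# Required marker: is_valid_backend() parses the file and looks for a
# module-level assignment to a dict named "ed_info".
ed_info = {
    "name": "Example backend",
    "version": (0, 0, 1),
    "type": "backend",
}


def install_backend():
    pass  # one-time setup, e.g. creating an environment


def is_installed():
    return True


def start_backend():
    print("example backend started")  # called by backend_manager.start_backend()


def stop_backend():
    pass  # called before switching to a different backend
```

Selecting it is then a matter of setting `backend: example` in `config.yaml`.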


@@ -0,0 +1,28 @@
from sdkit_common import (
start_backend,
stop_backend,
install_backend,
uninstall_backend,
is_installed,
create_sdkit_context,
ping,
load_model,
unload_model,
set_options,
generate_images,
filter_images,
get_url,
stop_rendering,
refresh_models,
list_controlnet_filters,
)
ed_info = {
"name": "Classic backend for Easy Diffusion v2",
"version": (1, 0, 0),
"type": "backend",
}
def create_context():
return create_sdkit_context(use_diffusers=False)


@@ -0,0 +1,28 @@
from sdkit_common import (
start_backend,
stop_backend,
install_backend,
uninstall_backend,
is_installed,
create_sdkit_context,
ping,
load_model,
unload_model,
set_options,
generate_images,
filter_images,
get_url,
stop_rendering,
refresh_models,
list_controlnet_filters,
)
ed_info = {
"name": "Diffusers Backend for Easy Diffusion v3",
"version": (1, 0, 0),
"type": "backend",
}
def create_context():
return create_sdkit_context(use_diffusers=True)


@@ -0,0 +1,246 @@
from sdkit import Context
from easydiffusion.types import UserInitiatedStop
from sdkit.utils import (
diffusers_latent_samples_to_images,
gc,
img_to_base64_str,
latent_samples_to_images,
)
opts = {}
def install_backend():
pass
def start_backend():
print("Started sdkit backend")
def stop_backend():
pass
def uninstall_backend():
pass
def is_installed():
return True
def create_sdkit_context(use_diffusers):
c = Context()
c.test_diffusers = use_diffusers
return c
def ping(timeout=1):
return True
def load_model(context, model_type, **kwargs):
from sdkit.models import load_model
load_model(context, model_type, **kwargs)
def unload_model(context, model_type, **kwargs):
from sdkit.models import unload_model
unload_model(context, model_type, **kwargs)
def set_options(context, **kwargs):
if "vae_tiling" in kwargs and context.test_diffusers:
pipe = context.models["stable-diffusion"]["default"]
vae_tiling = kwargs["vae_tiling"]
if vae_tiling:
if hasattr(pipe, "enable_vae_tiling"):
pipe.enable_vae_tiling()
else:
if hasattr(pipe, "disable_vae_tiling"):
pipe.disable_vae_tiling()
for key in (
"output_format",
"output_quality",
"output_lossless",
"stream_image_progress",
"stream_image_progress_interval",
):
if key in kwargs:
opts[key] = kwargs[key]
def generate_images(
context: Context,
callback=None,
controlnet_filter=None,
distilled_guidance_scale: float = 3.5,
scheduler_name: str = "simple",
output_type="pil",
**req,
):
from sdkit.generate import generate_images
if req["init_image"] is not None and not context.test_diffusers:
req["sampler_name"] = "ddim"
gc(context)
context.stop_processing = False
if req["control_image"] and controlnet_filter:
controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter)
req["control_image"] = filter_images(context, req["control_image"], controlnet_filter)[0]
callback = make_step_callback(context, callback)
try:
images = generate_images(context, callback=callback, **req)
except UserInitiatedStop:
images = []
if context.partial_x_samples is not None:
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
else:
images = latent_samples_to_images(context, context.partial_x_samples)
finally:
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
if not context.test_diffusers:
del context.partial_x_samples
context.partial_x_samples = None
gc(context)
if output_type == "base64":
output_format = opts.get("output_format", "jpeg")
output_quality = opts.get("output_quality", 75)
output_lossless = opts.get("output_lossless", False)
images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images]
return images
def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"):
gc(context)
if "nsfw_checker" in filters:
filters.remove("nsfw_checker") # handled by ED directly
if len(filters) == 0:
return images
images = _filter_images(context, images, filters, filter_params)
if input_type == "base64":
output_format = opts.get("output_format", "jpg")
output_quality = opts.get("output_quality", 75)
output_lossless = opts.get("output_lossless", False)
images = [img_to_base64_str(img, output_format, output_quality, output_lossless) for img in images]
return images
def _filter_images(context, images, filters, filter_params={}):
from sdkit.filter import apply_filters
filters = filters if isinstance(filters, list) else [filters]
filters = convert_ED_controlnet_filter_name(filters)
for filter_name in filters:
params = filter_params.get(filter_name, {})
previous_state = before_filter(context, filter_name, params)
try:
images = apply_filters(context, filter_name, images, **params)
finally:
after_filter(context, filter_name, params, previous_state)
return images
def before_filter(context, filter_name, filter_params):
if filter_name == "codeformer":
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
upscale_faces = filter_params.get("upscale_faces", False)
if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths.get("realesrgan")
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
return prev_realesrgan_path
def after_filter(context, filter_name, filter_params, previous_state):
if filter_name == "codeformer":
prev_realesrgan_path = previous_state
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
def get_url():
pass
def stop_rendering(context):
context.stop_processing = True
def refresh_models():
pass
def list_controlnet_filters():
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
return cn_filters
def make_step_callback(context, callback):
def on_step(x_samples, i, *args):
stream_image_progress = opts.get("stream_image_progress", False)
stream_image_progress_interval = opts.get("stream_image_progress_interval", 3)
if context.test_diffusers:
context.partial_x_samples = (x_samples, args[0])
else:
context.partial_x_samples = x_samples
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
else:
images = latent_samples_to_images(context, context.partial_x_samples)
else:
images = None
if callback:
callback(images, i, *args)
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_step
def convert_ED_controlnet_filter_name(filter):
def cn(n):
if n.startswith("controlnet_"):
return n[len("controlnet_") :]
return n
if isinstance(filter, list):
return [cn(f) for f in filter]
return cn(filter)


@@ -0,0 +1,450 @@
import os
import platform
import subprocess
import threading
from threading import local
import psutil
import time
import shutil
from easydiffusion.app import ROOT_DIR, getConfig
from easydiffusion.model_manager import get_model_dirs
from easydiffusion.utils import log
from . import impl
from .impl import (
ping,
load_model,
unload_model,
set_options,
generate_images,
filter_images,
get_url,
stop_rendering,
refresh_models,
list_controlnet_filters,
)
ed_info = {
"name": "WebUI backend for Easy Diffusion",
"version": (1, 0, 0),
"type": "backend",
}
WEBUI_REPO = "https://github.com/lllyasviel/stable-diffusion-webui-forge.git"
WEBUI_COMMIT = "f4d5e8cac16a42fa939e78a0956b4c30e2b47bb5"
BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui"))
SYSTEM_DIR = os.path.join(BACKEND_DIR, "system")
WEBUI_DIR = os.path.join(BACKEND_DIR, "webui")
OS_NAME = platform.system()
MODELS_TO_OVERRIDE = {
"stable-diffusion": "--ckpt-dir",
"vae": "--vae-dir",
"hypernetwork": "--hypernetwork-dir",
"gfpgan": "--gfpgan-models-path",
"realesrgan": "--realesrgan-models-path",
"lora": "--lora-dir",
"codeformer": "--codeformer-models-path",
"embeddings": "--embeddings-dir",
"controlnet": "--controlnet-dir",
}
backend_process = None
conda = "conda"
def locate_conda():
global conda
which = "where" if OS_NAME == "Windows" else "which"
conda = subprocess.getoutput(f"{which} conda")
conda = conda.split("\n")
conda = conda[0].strip()
print("conda: ", conda)
locate_conda()
def install_backend():
print("Installing the WebUI backend..")
# create the conda env
run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR)
print("Installing packages..")
# install python 3.10 and git in the conda env
run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR)
# print info
run_in_conda(["git", "--version"], cwd=ROOT_DIR)
run_in_conda(["python", "--version"], cwd=ROOT_DIR)
# clone webui
run_in_conda(["git", "clone", WEBUI_REPO, WEBUI_DIR], cwd=ROOT_DIR)
# install cpu-only torch if the PC doesn't have a graphics card (for Windows and Linux).
# this avoids WebUI installing a CUDA version and trying to activate it
if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card():
run_in_conda(["python", "-m", "pip", "install", "torch", "torchvision"], cwd=WEBUI_DIR)
def start_backend():
config = getConfig()
backend_config = config.get("backend_config", {})
if not os.path.exists(BACKEND_DIR):
install_backend()
was_still_installing = not is_installed()
if backend_config.get("auto_update", True):
run_in_conda(["git", "add", "-A", "."], cwd=WEBUI_DIR)
run_in_conda(["git", "stash"], cwd=WEBUI_DIR)
run_in_conda(["git", "reset", "--hard"], cwd=WEBUI_DIR)
run_in_conda(["git", "fetch"], cwd=WEBUI_DIR)
run_in_conda(["git", "-c", "advice.detachedHead=false", "checkout", WEBUI_COMMIT], cwd=WEBUI_DIR)
# hack to prevent webui-macos-env.sh from overwriting the COMMANDLINE_ARGS env variable
mac_webui_file = os.path.join(WEBUI_DIR, "webui-macos-env.sh")
if os.path.exists(mac_webui_file):
os.remove(mac_webui_file)
impl.WEBUI_HOST = backend_config.get("host", "localhost")
impl.WEBUI_PORT = backend_config.get("port", "7860")
env = dict(os.environ)
env.update(get_env())
def restart_if_webui_dies_after_starting():
has_started = False
while True:
try:
impl.ping(timeout=1)
is_first_start = not has_started
has_started = True
if was_still_installing and is_first_start:
ui = config.get("ui", {})
net = config.get("net", {})
port = net.get("listen_port", 9000)
if ui.get("open_browser_on_start", True):
import webbrowser
log.info("Opening browser..")
webbrowser.open(f"http://localhost:{port}")
except (TimeoutError, ConnectionError):
if has_started: # process probably died
print("######################## WebUI probably died. Restarting...")
stop_backend()
backend_thread = threading.Thread(target=target)
backend_thread.start()
break
except Exception:
import traceback
log.exception(traceback.format_exc())
time.sleep(1)
def target():
global backend_process
cmd = "webui.bat" if OS_NAME == "Windows" else "./webui.sh"
print("starting", cmd, WEBUI_DIR)
backend_process = run_in_conda([cmd], cwd=WEBUI_DIR, env=env, wait=False, output_prefix="[WebUI] ")
restart_if_dead_thread = threading.Thread(target=restart_if_webui_dies_after_starting)
restart_if_dead_thread.start()
backend_process.wait()
backend_thread = threading.Thread(target=target)
backend_thread.start()
start_proxy()
def start_proxy():
# proxy
from easydiffusion.server import server_api
from fastapi import FastAPI, Request
from fastapi.responses import Response
import json
URI_PREFIX = "/webui"
webui_proxy = FastAPI(root_path=f"{URI_PREFIX}", docs_url="/swagger")
@webui_proxy.get("{uri:path}")
def proxy_get(uri: str, req: Request):
if uri == "/openapi-proxy.json":
uri = "/openapi.json"
res = impl.webui_get(uri, headers=req.headers)
content = res.content
headers = dict(res.headers)
if uri == "/docs":
content = res.text.replace("url: '/openapi.json'", f"url: '{URI_PREFIX}/openapi-proxy.json'")
elif uri == "/openapi.json":
content = res.json()
content["paths"] = {f"{URI_PREFIX}{k}": v for k, v in content["paths"].items()}
content = json.dumps(content)
if isinstance(content, str):
content = bytes(content, encoding="utf-8")
headers["content-length"] = str(len(content))
# Return the same response back to the client
return Response(content=content, status_code=res.status_code, headers=headers)
@webui_proxy.post("{uri:path}")
async def proxy_post(uri: str, req: Request):
body = await req.body()
res = impl.webui_post(uri, data=body, headers=req.headers)
# Return the same response back to the client
return Response(content=res.content, status_code=res.status_code, headers=dict(res.headers))
server_api.mount(f"{URI_PREFIX}", webui_proxy)
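# Illustration (assuming Easy Diffusion serves its own API on the default port 9000):
# once mounted, the WebUI REST API is reachable through Easy Diffusion itself, e.g.
# GET http://localhost:9000/webui/sdapi/v1/options, and the proxied Swagger page at
# http://localhost:9000/webui/docs (with its openapi URL rewritten by proxy_get above).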
def stop_backend():
global backend_process
if backend_process:
try:
kill(backend_process.pid)
except:
pass
backend_process = None
def uninstall_backend():
shutil.rmtree(BACKEND_DIR)
def is_installed():
if not os.path.exists(BACKEND_DIR) or not os.path.exists(SYSTEM_DIR) or not os.path.exists(WEBUI_DIR):
return True
env = dict(os.environ)
env.update(get_env())
try:
out = check_output_in_conda(["python", "-m", "pip", "show", "torch"], env=env)
return "Version" in out.decode()
except subprocess.CalledProcessError:
pass
return False
def read_output(pipe, prefix=""):
while True:
output = pipe.readline()
if output:
print(f"{prefix}{output.decode('utf-8')}", end="")
else:
break # Pipe is closed, subprocess has likely exited
def run(cmds: list, cwd=None, env=None, stream_output=True, wait=True, output_prefix=""):
p = subprocess.Popen(cmds, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
if stream_output:
output_thread = threading.Thread(target=read_output, args=(p.stdout, output_prefix))
output_thread.start()
if wait:
p.wait()
return p
def run_in_conda(cmds: list, *args, **kwargs):
cmds = [conda, "run", "--no-capture-output", "--prefix", SYSTEM_DIR] + cmds
return run(cmds, *args, **kwargs)
def check_output_in_conda(cmds: list, cwd=None, env=None):
cmds = [conda, "run", "--no-capture-output", "--prefix", SYSTEM_DIR] + cmds
return subprocess.check_output(cmds, cwd=cwd, env=env, stderr=subprocess.PIPE)
def create_context():
context = local()
# temp hack, throws an attribute not found error otherwise
context.device = "cuda:0"
context.half_precision = True
context.vram_usage_level = None
context.models = {}
context.model_paths = {}
context.model_configs = {}
context.device_name = None
context.vram_optimizations = set()
context.vram_usage_level = "balanced"
context.test_diffusers = False
context.enable_codeformer = False
return context
def get_env():
dir = os.path.abspath(SYSTEM_DIR)
if not os.path.exists(dir):
raise RuntimeError("The system folder is missing!")
config = getConfig()
models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
model_path_args = get_model_path_args()
env_entries = {
"PATH": [
f"{dir}",
f"{dir}/bin",
f"{dir}/Library/bin",
f"{dir}/Scripts",
f"{dir}/usr/bin",
],
"PYTHONPATH": [
f"{dir}",
f"{dir}/lib/site-packages",
f"{dir}/lib/python3.10/site-packages",
],
"PYTHONHOME": [],
"PY_LIBS": [
f"{dir}/Scripts/Lib",
f"{dir}/Scripts/Lib/site-packages",
f"{dir}/lib",
f"{dir}/lib/python3.10/site-packages",
],
"PY_PIP": [f"{dir}/Scripts", f"{dir}/bin"],
"PIP_INSTALLER_LOCATION": [], # [f"{dir}/python/get-pip.py"],
"TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"],
"HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"],
"COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args} --skip-torch-cuda-test'],
"SKIP_VENV": ["1"],
"SD_WEBUI_RESTARTING": ["1"],
}
if OS_NAME == "Windows":
env_entries["PATH"].append("C:/Windows/System32")
env_entries["PATH"].append("C:/Windows/System32/wbem")
env_entries["PYTHONNOUSERSITE"] = ["1"]
env_entries["PYTHON"] = [f"{dir}/python"]
env_entries["GIT"] = [f"{dir}/Library/bin/git"]
else:
env_entries["PATH"].append("/bin")
env_entries["PATH"].append("/usr/bin")
env_entries["PATH"].append("/usr/sbin")
env_entries["PYTHONNOUSERSITE"] = ["y"]
env_entries["PYTHON"] = [f"{dir}/bin/python"]
env_entries["GIT"] = [f"{dir}/bin/git"]
env_entries["venv_dir"] = ["-"]
if OS_NAME == "Darwin":
# based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/e26abf87ecd1eefd9ab0a198eee56f9c643e4001/webui-macos-env.sh
# hack - have to define these here, otherwise webui-macos-env.sh will overwrite COMMANDLINE_ARGS
env_entries["COMMANDLINE_ARGS"][0] += " --upcast-sampling --no-half-vae --use-cpu interrogate"
env_entries["PYTORCH_ENABLE_MPS_FALLBACK"] = ["1"]
cpu_name = str(subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]))
if "Intel" in cpu_name:
env_entries["TORCH_COMMAND"] = ["pip install torch==2.1.2 torchvision==0.16.2"]
else:
env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
else:
import torch
from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
vram_usage_level = config.get("vram_usage_level", "balanced")
if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
else:
c = local()
c.device_name = torch.cuda.get_device_name()
if needs_to_force_full_precision(c):
env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
if vram_usage_level == "low":
env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram"
elif vram_usage_level == "high":
env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram"
env = {}
for key, paths in env_entries.items():
paths = [p.replace("/", os.path.sep) for p in paths]
paths = os.pathsep.join(paths)
env[key] = paths
return env
def has_discrete_graphics_card():
system = OS_NAME
if system == "Windows":
try:
output = subprocess.check_output(
["wmic", "path", "win32_videocontroller", "get", "name"], stderr=subprocess.STDOUT
)
# Filter for discrete graphics cards (NVIDIA, AMD, etc.)
discrete_gpus = ["NVIDIA", "AMD", "ATI"]
return any(gpu in output.decode() for gpu in discrete_gpus)
except subprocess.CalledProcessError:
return False
elif system == "Linux":
try:
output = subprocess.check_output(["lspci"], stderr=subprocess.STDOUT)
# Check for discrete GPUs (NVIDIA, AMD)
discrete_gpus = ["NVIDIA", "AMD", "Advanced Micro Devices"]
return any(gpu in line for line in output.decode().splitlines() for gpu in discrete_gpus)
except subprocess.CalledProcessError:
return False
elif system == "Darwin": # macOS
try:
output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], stderr=subprocess.STDOUT)
# Check for discrete GPU in the output
return "NVIDIA" in output.decode() or "AMD" in output.decode()
except subprocess.CalledProcessError:
return False
return False
# https://stackoverflow.com/a/25134985
def kill(proc_pid):
process = psutil.Process(proc_pid)
for proc in process.children(recursive=True):
proc.kill()
process.kill()
def get_model_path_args():
args = []
for model_type, flag in MODELS_TO_OVERRIDE.items():
model_dir = get_model_dirs(model_type)[0]
args.append(f'{flag} "{model_dir}"')
return " ".join(args)


@@ -0,0 +1,654 @@
import os
import requests
from requests.exceptions import ConnectTimeout, ConnectionError
from typing import Union, List
from threading import local as Context
from threading import Thread
import uuid
import time
from copy import deepcopy
from sdkit.utils import base64_str_to_img, img_to_base64_str
WEBUI_HOST = "localhost"
WEBUI_PORT = "7860"
DEFAULT_WEBUI_OPTIONS = {
"show_progress_every_n_steps": 3,
"show_progress_grid": True,
"live_previews_enable": False,
"forge_additional_modules": [],
}
webui_opts: dict = None
curr_models = {
"stable-diffusion": None,
"vae": None,
}
def set_options(context, **kwargs):
changed_opts = {}
opts_mapping = {
"stream_image_progress": ("live_previews_enable", bool),
"stream_image_progress_interval": ("show_progress_every_n_steps", int),
"clip_skip": ("CLIP_stop_at_last_layers", int),
"clip_skip_sdxl": ("sdxl_clip_l_skip", bool),
"output_format": ("samples_format", str),
}
for ed_key, webui_key in opts_mapping.items():
webui_key, webui_type = webui_key
if ed_key in kwargs and (webui_opts is None or webui_opts.get(webui_key, False) != webui_type(kwargs[ed_key])):
changed_opts[webui_key] = webui_type(kwargs[ed_key])
if changed_opts:
changed_opts["sd_model_checkpoint"] = curr_models["stable-diffusion"]
print(f"Got options: {kwargs}. Sending options: {changed_opts}")
try:
res = webui_post("/sdapi/v1/options", json=changed_opts)
if res.status_code != 200:
raise Exception(res.text)
webui_opts.update(changed_opts)
except Exception as e:
print(f"Error setting options: {e}")
def ping(timeout=1):
"timeout (in seconds)"
global webui_opts
try:
res = webui_get("/internal/ping", timeout=timeout)
if res.status_code != 200:
raise ConnectTimeout(res.text)
if webui_opts is None:
try:
res = webui_post("/sdapi/v1/options", json=DEFAULT_WEBUI_OPTIONS)
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
print(f"Error setting options: {e}")
try:
res = webui_get("/sdapi/v1/options")
if res.status_code != 200:
raise Exception(res.text)
webui_opts = res.json()
except Exception as e:
print(f"Error getting options: {e}")
return True
except (ConnectTimeout, ConnectionError) as e:
raise TimeoutError(e)
def load_model(context, model_type, **kwargs):
model_path = context.model_paths[model_type]
if webui_opts is None:
print("Server not ready, can't set the model")
return
if model_type == "stable-diffusion":
model_name = os.path.basename(model_path)
model_name = os.path.splitext(model_name)[0]
print(f"setting sd model: {model_name}")
if curr_models[model_type] != model_name:
try:
res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name})
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
raise RuntimeError(
f"The engine failed to set the required options. Please check the logs in the command line window for more details."
)
curr_models[model_type] = model_name
elif model_type == "vae":
if curr_models[model_type] != model_path:
vae_model = [model_path] if model_path else []
opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model}
print("setting opts 2", opts)
try:
res = webui_post("/sdapi/v1/options", json=opts)
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
raise RuntimeError(
f"The engine failed to set the required options. Please check the logs in the command line window for more details."
)
curr_models[model_type] = model_path
def unload_model(context, model_type, **kwargs):
if model_type == "vae":
context.model_paths[model_type] = None
load_model(context, model_type)
def generate_images(
context: Context,
prompt: str = "",
negative_prompt: str = "",
seed: int = 42,
width: int = 512,
height: int = 512,
num_outputs: int = 1,
num_inference_steps: int = 25,
guidance_scale: float = 7.5,
distilled_guidance_scale: float = 3.5,
init_image=None,
init_image_mask=None,
control_image=None,
control_alpha=1.0,
controlnet_filter=None,
prompt_strength: float = 0.8,
preserve_init_image_color_profile=False,
strict_mask_border=False,
sampler_name: str = "euler_a",
scheduler_name: str = "simple",
hypernetwork_strength: float = 0,
tiling=None,
lora_alpha: Union[float, List[float]] = 0,
sampler_params={},
callback=None,
output_type="pil",
):
task_id = str(uuid.uuid4())
sampler_name = convert_ED_sampler_names(sampler_name)
controlnet_filter = convert_ED_controlnet_filter_name(controlnet_filter)
cmd = {
"force_task_id": task_id,
"prompt": prompt,
"negative_prompt": negative_prompt,
"sampler_name": sampler_name,
"scheduler": scheduler_name,
"steps": num_inference_steps,
"seed": seed,
"cfg_scale": guidance_scale,
"distilled_cfg_scale": distilled_guidance_scale,
"batch_size": num_outputs,
"width": width,
"height": height,
}
if init_image:
cmd["init_images"] = [init_image]
cmd["denoising_strength"] = prompt_strength
if init_image_mask:
cmd["mask"] = init_image_mask
cmd["include_init_images"] = True
cmd["inpainting_fill"] = 1
cmd["initial_noise_multiplier"] = 1
cmd["inpaint_full_res"] = 1
if context.model_paths.get("lora"):
lora_model = context.model_paths["lora"]
lora_model = lora_model if isinstance(lora_model, list) else [lora_model]
lora_alpha = lora_alpha if isinstance(lora_alpha, list) else [lora_alpha]
for lora, alpha in zip(lora_model, lora_alpha):
lora = os.path.basename(lora)
lora = os.path.splitext(lora)[0]
cmd["prompt"] += f" <lora:{lora}:{alpha}>"
if controlnet_filter and control_image and context.model_paths.get("controlnet"):
controlnet_model = context.model_paths["controlnet"]
model_hash = auto1111_hash(controlnet_model)
controlnet_model = os.path.basename(controlnet_model)
controlnet_model = os.path.splitext(controlnet_model)[0]
print(f"setting controlnet model: {controlnet_model}")
controlnet_model = f"{controlnet_model} [{model_hash}]"
cmd["alwayson_scripts"] = {
"controlnet": {
"args": [
{
"image": control_image,
"weight": control_alpha,
"module": controlnet_filter,
"model": controlnet_model,
"resize_mode": "Crop and Resize",
"threshold_a": 50,
"threshold_b": 130,
}
]
}
}
operation_to_apply = "img2img" if init_image else "txt2img"
stream_image_progress = webui_opts.get("live_previews_enable", False)
progress_thread = Thread(
target=image_progress_thread, args=(task_id, callback, stream_image_progress, num_outputs, num_inference_steps)
)
progress_thread.start()
print(f"task id: {task_id}")
print_request(operation_to_apply, cmd)
res = webui_post(f"/sdapi/v1/{operation_to_apply}", json=cmd)
if res.status_code == 200:
res = res.json()
else:
raise Exception(
"The engine failed while generating this image. Please check the logs in the command-line window for more details."
)
import json
print(json.loads(res["info"])["infotexts"])
images = res["images"]
if output_type == "pil":
images = [base64_str_to_img(img) for img in images]
elif output_type == "base64":
images = [base64_buffer_to_base64_img(img) for img in images]
return images
def filter_images(context: Context, images, filters, filter_params={}, input_type="pil"):
"""
* context: Context
* images: str or PIL.Image or list of str/PIL.Image - image to filter. if a string is passed, it needs to be a base64-encoded image
* filters: filter_type (string) or list of strings
* filter_params: dict
returns: [PIL.Image] - list of filtered images
"""
images = images if isinstance(images, list) else [images]
filters = filters if isinstance(filters, list) else [filters]
if "nsfw_checker" in filters:
filters.remove("nsfw_checker") # handled by ED directly
args = {}
controlnet_filters = []
print(filter_params)
for filter_name in filters:
params = filter_params.get(filter_name, {})
if filter_name == "gfpgan":
args["gfpgan_visibility"] = 1
if filter_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
args["upscaler_1"] = params.get("upscaler", "RealESRGAN_x4plus")
args["upscaling_resize"] = params.get("scale", 4)
if args["upscaler_1"] == "RealESRGAN_x4plus":
args["upscaler_1"] = "R-ESRGAN 4x+"
elif args["upscaler_1"] == "RealESRGAN_x4plus_anime_6B":
args["upscaler_1"] = "R-ESRGAN 4x+ Anime6B"
if filter_name == "codeformer":
args["codeformer_visibility"] = 1
args["codeformer_weight"] = params.get("codeformer_fidelity", 0.5)
if filter_name.startswith("controlnet_"):
filter_name = convert_ED_controlnet_filter_name(filter_name)
controlnet_filters.append(filter_name)
print(f"filtering {len(images)} images with {args}. {controlnet_filters=}")
if len(filters) > len(controlnet_filters):
filtered_images = extra_batch_images(images, input_type=input_type, **args)
else:
filtered_images = images
for filter_name in controlnet_filters:
filtered_images = controlnet_filter(filtered_images, module=filter_name, input_type=input_type)
return filtered_images
def get_url():
return f"//{WEBUI_HOST}:{WEBUI_PORT}/?__theme=dark"
def stop_rendering(context):
try:
res = webui_post("/sdapi/v1/interrupt")
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
print(f"Error interrupting webui: {e}")
def refresh_models():
def make_refresh_call(type):
try:
webui_post(f"/sdapi/v1/refresh-{type}")
except:
pass
try:
for type in ("checkpoints", "vae"):
t = Thread(target=make_refresh_call, args=(type,))
t.start()
except Exception as e:
print(f"Error refreshing models: {e}")
def list_controlnet_filters():
return [
"openpose",
"openpose_face",
"openpose_faceonly",
"openpose_hand",
"openpose_full",
"animal_openpose",
"densepose_parula (black bg & blue torso)",
"densepose (pruple bg & purple torso)",
"dw_openpose_full",
"mediapipe_face",
"instant_id_face_keypoints",
"InsightFace+CLIP-H (IPAdapter)",
"InsightFace (InstantID)",
"canny",
"mlsd",
"scribble_hed",
"scribble_hedsafe",
"scribble_pidinet",
"scribble_pidsafe",
"scribble_xdog",
"softedge_hed",
"softedge_hedsafe",
"softedge_pidinet",
"softedge_pidsafe",
"softedge_teed",
"normal_bae",
"depth_midas",
"normal_midas",
"depth_zoe",
"depth_leres",
"depth_leres++",
"depth_anything_v2",
"depth_anything",
"depth_hand_refiner",
"depth_marigold",
"lineart_coarse",
"lineart_realistic",
"lineart_anime",
"lineart_standard (from white bg & black line)",
"lineart_anime_denoise",
"reference_adain",
"reference_only",
"reference_adain+attn",
"tile_colorfix",
"tile_resample",
"tile_colorfix+sharp",
"CLIP-ViT-H (IPAdapter)",
"CLIP-G (Revision)",
"CLIP-G (Revision ignore prompt)",
"CLIP-ViT-bigG (IPAdapter)",
"InsightFace+CLIP-H (IPAdapter)",
"inpaint_only",
"inpaint_only+lama",
"inpaint_global_harmonious",
"seg_ufade20k",
"seg_ofade20k",
"seg_anime_face",
"seg_ofcoco",
"shuffle",
"segment",
"invert (from white bg & black line)",
"threshold",
"t2ia_sketch_pidi",
"t2ia_color_grid",
"recolor_intensity",
"recolor_luminance",
"blur_gaussian",
]
def controlnet_filter(images, module="none", processor_res=512, threshold_a=64, threshold_b=64, input_type="pil"):
if input_type == "pil":
images = [img_to_base64_str(x) for x in images]
payload = {
"controlnet_module": module,
"controlnet_input_images": images,
"controlnet_processor_res": processor_res,
"controlnet_threshold_a": threshold_a,
"controlnet_threshold_b": threshold_b,
}
res = webui_post("/controlnet/detect", json=payload)
res = res.json()
filtered_images = res["images"]
if input_type == "pil":
filtered_images = [base64_str_to_img(img) for img in filtered_images]
elif input_type == "base64":
filtered_images = [base64_buffer_to_base64_img(img) for img in filtered_images]
return filtered_images
def image_progress_thread(task_id, callback, stream_image_progress, total_images, total_steps):
from PIL import Image
last_preview_id = -1
EMPTY_IMAGE = Image.new("RGB", (1, 1))
while True:
res = webui_post(
f"/internal/progress",
json={"id_task": task_id, "live_preview": stream_image_progress, "id_live_preview": last_preview_id},
)
if res.status_code == 200:
res = res.json()
else:
raise RuntimeError(f"Unexpected progress response. Status code: {res.status_code}. Res: {res.text}")
last_preview_id = res["id_live_preview"]
if res["progress"] is not None:
step_num = int(res["progress"] * total_steps)
if res["live_preview"] is not None:
img = res["live_preview"]
img = base64_str_to_img(img)
images = [EMPTY_IMAGE] * total_images
images[0] = img
else:
images = None
callback(images, step_num)
if res["completed"] == True:
print("Complete!")
break
time.sleep(0.5)
def webui_get(uri, *args, **kwargs):
url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}"
return requests.get(url, *args, **kwargs)
def webui_post(uri, *args, **kwargs):
url = f"http://{WEBUI_HOST}:{WEBUI_PORT}{uri}"
return requests.post(url, *args, **kwargs)
def print_request(operation_to_apply, args):
args = deepcopy(args)
if "init_images" in args:
args["init_images"] = ["img" for _ in args["init_images"]]
if "mask" in args:
args["mask"] = "mask_img"
controlnet_args = args.get("alwayson_scripts", {}).get("controlnet", {}).get("args", [])
if controlnet_args:
controlnet_args[0]["image"] = "control_image"
print(f"operation: {operation_to_apply}, args: {args}")
def auto1111_hash(file_path):
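# Matches Automatic1111's legacy "short hash": sha256 of the 64 KiB at offset 1 MiB of the
# model file, truncated to 8 hex characters (used above to reference the ControlNet model
# as "name [hash]").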
import hashlib
with open(file_path, "rb") as f:
f.seek(0x100000)
b = f.read(0x10000)
return hashlib.sha256(b).hexdigest()[:8]
def extra_batch_images(
images, # list of PIL images
name_list=None, # list of image names
resize_mode=0,
show_extras_results=True,
gfpgan_visibility=0,
codeformer_visibility=0,
codeformer_weight=0,
upscaling_resize=2,
upscaling_resize_w=512,
upscaling_resize_h=512,
upscaling_crop=True,
upscaler_1="None",
upscaler_2="None",
extras_upscaler_2_visibility=0,
upscale_first=False,
use_async=False,
input_type="pil",
):
if name_list is not None:
if len(name_list) != len(images):
raise RuntimeError("len(images) != len(name_list)")
else:
name_list = [f"image{i + 1:05}" for i in range(len(images))]
if input_type == "pil":
images = [img_to_base64_str(x) for x in images]
image_list = []
for name, image in zip(name_list, images):
image_list.append({"data": image, "name": name})
payload = {
"resize_mode": resize_mode,
"show_extras_results": show_extras_results,
"gfpgan_visibility": gfpgan_visibility,
"codeformer_visibility": codeformer_visibility,
"codeformer_weight": codeformer_weight,
"upscaling_resize": upscaling_resize,
"upscaling_resize_w": upscaling_resize_w,
"upscaling_resize_h": upscaling_resize_h,
"upscaling_crop": upscaling_crop,
"upscaler_1": upscaler_1,
"upscaler_2": upscaler_2,
"extras_upscaler_2_visibility": extras_upscaler_2_visibility,
"upscale_first": upscale_first,
"imageList": image_list,
}
res = webui_post("/sdapi/v1/extra-batch-images", json=payload)
if res.status_code == 200:
res = res.json()
else:
raise Exception(
"The engine failed while filtering this image. Please check the logs in the command-line window for more details."
)
images = res["images"]
if input_type == "pil":
images = [base64_str_to_img(img) for img in images]
elif input_type == "base64":
images = [base64_buffer_to_base64_img(img) for img in images]
return images
def base64_buffer_to_base64_img(img):
output_format = webui_opts.get("samples_format", "jpeg")
mime_type = f"image/{output_format.lower()}"
return f"data:{mime_type};base64," + img
def convert_ED_sampler_names(sampler_name):
name_mapping = {
"dpmpp_2m": "DPM++ 2M",
"dpmpp_sde": "DPM++ SDE",
"dpmpp_2m_sde": "DPM++ 2M SDE",
"dpmpp_2m_sde_heun": "DPM++ 2M SDE Heun",
"dpmpp_2s_a": "DPM++ 2S a",
"dpmpp_3m_sde": "DPM++ 3M SDE",
"euler_a": "Euler a",
"euler": "Euler",
"lms": "LMS",
"heun": "Heun",
"dpm2": "DPM2",
"dpm2_a": "DPM2 a",
"dpm_fast": "DPM fast",
"dpm_adaptive": "DPM adaptive",
"restart": "Restart",
"heun_pp2": "HeunPP2",
"ipndm": "IPNDM",
"ipndm_v": "IPNDM_V",
"deis": "DEIS",
"ddim": "DDIM",
"ddim_cfgpp": "DDIM CFG++",
"plms": "PLMS",
"unipc": "UniPC",
"lcm": "LCM",
"ddpm": "DDPM",
"forge_flux_realistic": "[Forge] Flux Realistic",
"forge_flux_realistic_slow": "[Forge] Flux Realistic (Slow)",
# deprecated samplers in 3.5
"dpm_solver_stability": None,
"unipc_snr": None,
"unipc_tu": None,
"unipc_snr_2": None,
"unipc_tu_2": None,
"unipc_tq": None,
}
return name_mapping.get(sampler_name)
def convert_ED_controlnet_filter_name(filter):
if filter is None:
return None
def cn(n):
if n.startswith("controlnet_"):
return n[len("controlnet_") :]
return n
mapping = {
"controlnet_scribble_hedsafe": None,
"controlnet_scribble_pidsafe": None,
"controlnet_softedge_pidsafe": "controlnet_softedge_pidisafe",
"controlnet_normal_bae": "controlnet_normalbae",
"controlnet_segment": None,
}
if isinstance(filter, list):
return [cn(mapping.get(f, f)) for f in filter]
return cn(mapping.get(filter, filter))
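A quick illustration of the two mappings above, with values taken from the tables in this file:

```python
convert_ED_sampler_names("euler_a")               # -> "Euler a"
convert_ED_sampler_names("dpm_solver_stability")  # -> None (deprecated in 3.5)
convert_ED_controlnet_filter_name("controlnet_normal_bae")  # -> "normalbae"
```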


@@ -33,4 +33,3 @@ class Bucket(BucketBase):
class Config:
orm_mode = True


@@ -8,7 +8,7 @@ from easydiffusion import app
from easydiffusion.types import ModelsData
from easydiffusion.utils import log
from sdkit import Context
-from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
+from sdkit.models import scan_model, download_model, get_model_info_from_db
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
from sdkit.utils import hash_file_quick
from sdkit.models.model_loader.embeddings import get_embedding_token
@@ -25,19 +25,19 @@ KNOWN_MODEL_TYPES = [
"controlnet",
]
MODEL_EXTENSIONS = {
-"stable-diffusion": [".ckpt", ".safetensors"],
+"stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"],
-"vae": [".vae.pt", ".ckpt", ".safetensors"],
+"vae": [".vae.pt", ".ckpt", ".safetensors", ".sft"],
-"hypernetwork": [".pt", ".safetensors"],
+"hypernetwork": [".pt", ".safetensors", ".sft"],
"gfpgan": [".pth"],
"realesrgan": [".pth"],
-"lora": [".ckpt", ".safetensors", ".pt"],
+"lora": [".ckpt", ".safetensors", ".sft", ".pt"],
"codeformer": [".pth"],
-"embeddings": [".pt", ".bin", ".safetensors"],
+"embeddings": [".pt", ".bin", ".safetensors", ".sft"],
-"controlnet": [".pth", ".safetensors"],
+"controlnet": [".pth", ".safetensors", ".sft"],
}
DEFAULT_MODELS = {
"stable-diffusion": [
-{"file_name": "sd-v1-5.safetensors", "model_id": "1.5-pruned-emaonly-fp16"},
+{"file_name": "sd-v1-4.ckpt", "model_id": "1.4"},
],
"gfpgan": [
{"file_name": "GFPGANv1.4.pth", "model_id": "1.4"},
@@ -51,6 +51,16 @@ DEFAULT_MODELS = {
],
}
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility
"stable-diffusion": "Stable-diffusion",
"vae": "VAE",
"hypernetwork": "hypernetworks",
"codeformer": "Codeformer",
"gfpgan": "GFPGAN",
"realesrgan": "RealESRGAN",
"lora": "Lora",
"controlnet": "ControlNet",
}
known_models = {} known_models = {}
@ -63,6 +73,7 @@ def init():
def load_default_models(context: Context): def load_default_models(context: Context):
from easydiffusion import runtime from easydiffusion import runtime
from easydiffusion.backend_manager import backend
runtime.set_vram_optimizations(context) runtime.set_vram_optimizations(context)
@ -70,7 +81,7 @@ def load_default_models(context: Context):
for model_type in MODELS_TO_LOAD_ON_START: for model_type in MODELS_TO_LOAD_ON_START:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False) context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False)
try: try:
load_model( backend.load_model(
context, context,
model_type, model_type,
scan_model=context.model_paths[model_type] != None scan_model=context.model_paths[model_type] != None
@ -92,9 +103,11 @@ def load_default_models(context: Context):
def unload_all(context: Context): def unload_all(context: Context):
from easydiffusion.backend_manager import backend
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
unload_model(context, model_type) backend.unload_model(context, model_type)
if model_type in context.model_load_errors: if hasattr(context, "model_load_errors") and model_type in context.model_load_errors:
del context.model_load_errors[model_type] del context.model_load_errors[model_type]
@ -119,12 +132,12 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
default_models = DEFAULT_MODELS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig() config = app.getConfig()
model_dir = os.path.join(app.MODELS_DIR, model_type)
if not model_name: # When None try user configured model. if not model_name: # When None try user configured model.
# config = getConfig() # config = getConfig()
if "model" in config and model_type in config["model"]: if "model" in config and model_type in config["model"]:
model_name = config["model"][model_type] model_name = config["model"][model_type]
for model_dir in get_model_dirs(model_type):
if model_name: if model_name:
# Check models directory # Check models directory
model_path = os.path.join(model_dir, model_name) model_path = os.path.join(model_dir, model_name)
@ -154,6 +167,8 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []): def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
from easydiffusion.backend_manager import backend
models_to_reload = { models_to_reload = {
model_type: path model_type: path
for model_type, path in models_data.model_paths.items() for model_type, path in models_data.model_paths.items()
@ -175,7 +190,7 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models
for model_type, model_path_in_req in models_to_reload.items(): for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req context.model_paths[model_type] = model_path_in_req
action_fn = unload_model if context.model_paths[model_type] is None else load_model action_fn = backend.unload_model if context.model_paths[model_type] is None else backend.load_model
extra_params = models_data.model_params.get(model_type, {}) extra_params = models_data.model_params.get(model_type, {})
try: try:
action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already
@ -183,14 +198,23 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models
del context.model_load_errors[model_type] del context.model_load_errors[model_type]
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
if action_fn == load_model: if action_fn == backend.load_model:
context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks
def resolve_model_paths(models_data: ModelsData): def resolve_model_paths(models_data: ModelsData):
model_paths = models_data.model_paths model_paths = models_data.model_paths
skip_models = cn_filters + [
"latent_upscaler",
"nsfw_checker",
"esrgan_4x",
"lanczos",
"nearest",
"scunet",
"swinir",
]
for model_type in model_paths: for model_type in model_paths:
skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"]
if model_type in skip_models: # doesn't use model paths if model_type in skip_models: # doesn't use model paths
continue continue
if model_type == "codeformer" and model_paths[model_type]: if model_type == "codeformer" and model_paths[model_type]:
@ -225,7 +249,8 @@ def download_default_models_if_necessary():
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
model_path = os.path.join(app.MODELS_DIR, model_type, file_name) model_dir = get_model_dirs(model_type)[0]
model_path = os.path.join(model_dir, file_name)
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
other_models_exist = any_model_exists(model_type) and skip_if_others_exist other_models_exist = any_model_exists(model_type) and skip_if_others_exist
@ -245,13 +270,15 @@ def migrate_legacy_model_location():
file_name = model["file_name"] file_name = model["file_name"]
legacy_path = os.path.join(app.SD_DIR, file_name) legacy_path = os.path.join(app.SD_DIR, file_name)
if os.path.exists(legacy_path): if os.path.exists(legacy_path):
shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name)) model_dir = get_model_dirs(model_type)[0]
shutil.move(legacy_path, os.path.join(model_dir, file_name))
def any_model_exists(model_type: str) -> bool: def any_model_exists(model_type: str) -> bool:
extensions = MODEL_EXTENSIONS.get(model_type, []) extensions = MODEL_EXTENSIONS.get(model_type, [])
for model_dir in get_model_dirs(model_type):
for ext in extensions: for ext in extensions:
if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)): if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
return True return True
return False return False
@ -259,7 +286,7 @@ def any_model_exists(model_type: str) -> bool:
def make_model_folders(): def make_model_folders():
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type) model_dir_path = get_model_dirs(model_type)[0]
try: try:
os.makedirs(model_dir_path, exist_ok=True) os.makedirs(model_dir_path, exist_ok=True)
@ -322,6 +349,10 @@ def is_malicious_model(file_path):
def getModels(scan_for_malicious: bool = True): def getModels(scan_for_malicious: bool = True):
from easydiffusion.backend_manager import backend
backend.refresh_models()
models = { models = {
"options": { "options": {
"stable-diffusion": [], "stable-diffusion": [],
@ -331,19 +362,19 @@ def getModels(scan_for_malicious: bool = True):
"codeformer": [{"codeformer": "CodeFormer"}], "codeformer": [{"codeformer": "CodeFormer"}],
"embeddings": [], "embeddings": [],
"controlnet": [ "controlnet": [
{"control_v11p_sd15_canny": "Canny (*)"}, # {"control_v11p_sd15_canny": "Canny (*)"},
{"control_v11p_sd15_openpose": "OpenPose (*)"}, # {"control_v11p_sd15_openpose": "OpenPose (*)"},
{"control_v11p_sd15_normalbae": "Normal BAE (*)"}, # {"control_v11p_sd15_normalbae": "Normal BAE (*)"},
{"control_v11f1p_sd15_depth": "Depth (*)"}, # {"control_v11f1p_sd15_depth": "Depth (*)"},
{"control_v11p_sd15_scribble": "Scribble"}, # {"control_v11p_sd15_scribble": "Scribble"},
{"control_v11p_sd15_softedge": "Soft Edge"}, # {"control_v11p_sd15_softedge": "Soft Edge"},
{"control_v11p_sd15_inpaint": "Inpaint"}, # {"control_v11p_sd15_inpaint": "Inpaint"},
{"control_v11p_sd15_lineart": "Line Art"}, # {"control_v11p_sd15_lineart": "Line Art"},
{"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, # {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"},
{"control_v11p_sd15_mlsd": "Straight Lines"}, # {"control_v11p_sd15_mlsd": "Straight Lines"},
{"control_v11p_sd15_seg": "Segment"}, # {"control_v11p_sd15_seg": "Segment"},
{"control_v11e_sd15_shuffle": "Shuffle"}, # {"control_v11e_sd15_shuffle": "Shuffle"},
{"control_v11f1e_sd15_tile": "Tile"}, # {"control_v11f1e_sd15_tile": "Tile"},
], ],
}, },
} }
@ -358,6 +389,9 @@ def getModels(scan_for_malicious: bool = True):
tree = list(default_entries) tree = list(default_entries)
if not os.path.exists(directory):
return tree
for entry in sorted( for entry in sorted(
os.scandir(directory), os.scandir(directory),
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
@ -380,6 +414,8 @@ def getModels(scan_for_malicious: bool = True):
model_id = entry.name[: -len(matching_suffix)] model_id = entry.name[: -len(matching_suffix)]
if callable(nameFilter): if callable(nameFilter):
model_id = nameFilter(model_id) model_id = nameFilter(model_id)
if model_id is None:
continue
model_exists = False model_exists = False
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
@ -400,14 +436,15 @@ def getModels(scan_for_malicious: bool = True):
nonlocal models_scanned nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, []) model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = os.path.join(app.MODELS_DIR, model_type) models_dirs = get_model_dirs(model_type)
if not os.path.exists(models_dir): if not os.path.exists(models_dirs[0]):
os.makedirs(models_dir) os.makedirs(models_dirs[0])
for model_dir in models_dirs:
try: try:
default_tree = models["options"].get(model_type, []) default_tree = models["options"].get(model_type, [])
models["options"][model_type] = scan_directory( models["options"][model_type] = scan_directory(
models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter model_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
) )
except MaliciousModelException as e: except MaliciousModelException as e:
models["scan-error"] = str(e) models["scan-error"] = str(e)
@ -418,7 +455,7 @@ def getModels(scan_for_malicious: bool = True):
listModels(model_type="stable-diffusion") listModels(model_type="stable-diffusion")
listModels(model_type="vae") listModels(model_type="vae")
listModels(model_type="hypernetwork") listModels(model_type="hypernetwork")
listModels(model_type="gfpgan") listModels(model_type="gfpgan", nameFilter=lambda x: (x if "gfpgan" in x.lower() else None))
listModels(model_type="lora") listModels(model_type="lora")
listModels(model_type="embeddings", nameFilter=get_embedding_token) listModels(model_type="embeddings", nameFilter=get_embedding_token)
listModels(model_type="controlnet") listModels(model_type="controlnet")
@ -427,3 +464,20 @@ def getModels(scan_for_malicious: bool = True):
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
return models return models
def get_model_dirs(model_type: str, base_dir=None):
"Returns the possible model directory paths for the given model type. Mainly used for WebUI compatibility"
if base_dir is None:
base_dir = app.MODELS_DIR
dirs = [os.path.join(base_dir, model_type)]
if model_type in ALTERNATE_FOLDER_NAMES:
alt_dir = ALTERNATE_FOLDER_NAMES[model_type]
alt_dir = os.path.join(base_dir, alt_dir)
if os.path.exists(alt_dir) and os.path.isdir(alt_dir):
dirs.append(alt_dir)
return dirs
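For example, given the `ALTERNATE_FOLDER_NAMES` table above, a lookup behaves roughly like this (paths are illustrative):

```python
import os
import tempfile

# Illustrative: the ED-style folder is always returned first; the WebUI-style
# alias ("Lora") is appended only when it already exists on disk.
base = tempfile.mkdtemp()
os.makedirs(os.path.join(base, "Lora"))

assert get_model_dirs("lora", base_dir=base) == [
    os.path.join(base, "lora"),  # returned first, even if it doesn't exist yet
    os.path.join(base, "Lora"),  # WebUI-compatible alias, present, so appended
]
```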

View File

@ -3,8 +3,6 @@ import os
import platform import platform
from importlib.metadata import version as pkg_version from importlib.metadata import version as pkg_version
from sdkit.utils import log
from easydiffusion import app from easydiffusion import app
# future home of scripts/check_modules.py # future home of scripts/check_modules.py
@ -50,6 +48,8 @@ def is_installed(module_name) -> bool:
def install(module_name): def install(module_name):
from easydiffusion.utils import log
if is_installed(module_name): if is_installed(module_name):
log.info(f"{module_name} has already been installed!") log.info(f"{module_name} has already been installed!")
return return
@ -79,6 +79,8 @@ def install(module_name):
def uninstall(module_name): def uninstall(module_name):
from easydiffusion.utils import log
if not is_installed(module_name): if not is_installed(module_name):
log.info(f"{module_name} hasn't been installed!") log.info(f"{module_name} hasn't been installed!")
return return

View File

@ -1,4 +1,5 @@
""" """
(OUTDATED DOC)
A runtime that runs on a specific device (in a thread). A runtime that runs on a specific device (in a thread).
It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context. It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context.
@ -6,42 +7,35 @@ It can run various tasks like image generation, image filtering, model merge etc
This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function. This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function.
""" """
from easydiffusion import device_manager context = None
from easydiffusion.utils import log
from sdkit import Context
from sdkit.utils import get_device_usage
context = Context() # thread-local
"""
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
"""
def init(device): def init(device):
""" """
Initializes the fields that will be bound to this runtime's context, and sets the current torch device Initializes the fields that will be bound to this runtime's context, and sets the current torch device
""" """
global context
from easydiffusion import device_manager
from easydiffusion.backend_manager import backend
from easydiffusion.app import getConfig
context = backend.create_context()
context.stop_processing = False context.stop_processing = False
context.temp_images = {} context.temp_images = {}
context.partial_x_samples = None context.partial_x_samples = None
context.model_load_errors = {} context.model_load_errors = {}
context.enable_codeformer = True context.enable_codeformer = True
from easydiffusion import app
app_config = app.getConfig()
context.test_diffusers = app_config.get("use_v3_engine", True)
log.info("Device usage during initialization:")
get_device_usage(device, log_info=True, process_usage_only=False)
device_manager.device_init(context, device) device_manager.device_init(context, device)
def set_vram_optimizations(context: Context): def set_vram_optimizations(context):
from easydiffusion import app from easydiffusion.app import getConfig
config = app.getConfig() config = getConfig()
vram_usage_level = config.get("vram_usage_level", "balanced") vram_usage_level = config.get("vram_usage_level", "balanced")
if vram_usage_level != context.vram_usage_level: if vram_usage_level != context.vram_usage_level:

View File

@ -2,6 +2,7 @@
Notes: Notes:
async endpoints always run on the main thread. Otherwise they run on the thread pool. async endpoints always run on the main thread. Otherwise they run on the thread pool.
""" """
import datetime import datetime
import mimetypes import mimetypes
import os import os
@ -20,6 +21,7 @@ from easydiffusion.types import (
OutputFormatData, OutputFormatData,
SaveToDiskData, SaveToDiskData,
convert_legacy_render_req_to_new, convert_legacy_render_req_to_new,
convert_legacy_controlnet_filter_name,
) )
from easydiffusion.utils import log from easydiffusion.utils import log
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
@ -67,7 +69,9 @@ class SetAppConfigRequest(BaseModel, extra=Extra.allow):
listen_to_network: bool = None listen_to_network: bool = None
listen_port: int = None listen_port: int = None
use_v3_engine: bool = True use_v3_engine: bool = True
backend: str = "ed_diffusers"
models_dir: str = None models_dir: str = None
vram_usage_level: str = "balanced"
def init(): def init():
@ -155,6 +159,12 @@ def init():
def shutdown_event(): # Signal render thread to close on shutdown def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit("Application shutting down.") task_manager.current_state_error = SystemExit("Application shutting down.")
@server_api.on_event("startup")
def start_event():
from easydiffusion.app import open_browser
open_browser()
# API implementations # API implementations
def set_app_config_internal(req: SetAppConfigRequest): def set_app_config_internal(req: SetAppConfigRequest):
@ -176,8 +186,10 @@ def set_app_config_internal(req: SetAppConfigRequest):
config["net"] = {} config["net"] = {}
config["net"]["listen_port"] = int(req.listen_port) config["net"]["listen_port"] = int(req.listen_port)
config["use_v3_engine"] = req.use_v3_engine config["use_v3_engine"] = req.backend == "ed_diffusers"
config["backend"] = req.backend
config["models_dir"] = req.models_dir config["models_dir"] = req.models_dir
config["vram_usage_level"] = req.vram_usage_level
for property, property_value in req.dict().items(): for property, property_value in req.dict().items():
if property_value is not None and property not in req.__fields__ and property not in PROTECTED_CONFIG_KEYS: if property_value is not None and property not in req.__fields__ and property not in PROTECTED_CONFIG_KEYS:
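For reference, a settings request under the new scheme might look like this (values illustrative); `backend` replaces the old `use_v3_engine` boolean, which is now derived from it:

```python
# Illustrative payload: choosing the Forge/WebUI backend; the stored
# use_v3_engine flag is now derived from `backend` rather than set directly.
req = SetAppConfigRequest(backend="webui", models_dir="/data/models", vram_usage_level="low")
assert (req.backend == "ed_diffusers") is False  # i.e. config["use_v3_engine"] becomes False
```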
@ -216,6 +228,8 @@ def read_web_data_internal(key: str = None, **kwargs):
return JSONResponse(config, headers=NOCACHE_HEADERS) return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == "system_info": elif key == "system_info":
from easydiffusion.backend_manager import backend
config = app.getConfig() config = app.getConfig()
output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
@ -226,6 +240,7 @@ def read_web_data_internal(key: str = None, **kwargs):
"default_output_dir": output_dir, "default_output_dir": output_dir,
"enforce_output_dir": ("force_save_path" in config), "enforce_output_dir": ("force_save_path" in config),
"enforce_output_metadata": ("force_save_metadata" in config), "enforce_output_metadata": ("force_save_metadata" in config),
"backend_url": backend.get_url(),
} }
system_info["devices"]["config"] = config.get("render_devices", "auto") system_info["devices"]["config"] = config.get("render_devices", "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS) return JSONResponse(system_info, headers=NOCACHE_HEADERS)
@ -309,6 +324,15 @@ def filter_internal(req: dict):
output_format: OutputFormatData = OutputFormatData.parse_obj(req) output_format: OutputFormatData = OutputFormatData.parse_obj(req)
save_data: SaveToDiskData = SaveToDiskData.parse_obj(req) save_data: SaveToDiskData = SaveToDiskData.parse_obj(req)
filter_req.filter = convert_legacy_controlnet_filter_name(filter_req.filter)
for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
if models_data.model_paths.get(model_name):
if model_name not in filter_req.filter_params:
filter_req.filter_params[model_name] = {}
filter_req.filter_params[model_name]["upscaler"] = models_data.model_paths[model_name]
# enqueue the task # enqueue the task
task = FilterTask(filter_req, task_data, models_data, output_format, save_data) task = FilterTask(filter_req, task_data, models_data, output_format, save_data)
return enqueue_task(task) return enqueue_task(task)
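The loop above copies the selected upscaler model into that filter's params. A standalone re-creation of the effect (model names illustrative):

```python
# Standalone re-creation of the loop above (values illustrative):
filter_params = {}
model_paths = {"swinir": "SwinIR_4x"}

for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
    if model_paths.get(model_name):
        filter_params.setdefault(model_name, {})["upscaler"] = model_paths[model_name]

assert filter_params == {"swinir": {"upscaler": "SwinIR_4x"}}
```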
@ -342,15 +366,13 @@ def model_merge_internal(req: dict):
mergeReq: MergeRequest = MergeRequest.parse_obj(req) mergeReq: MergeRequest = MergeRequest.parse_obj(req)
sd_model_dir = model_manager.get_model_dirs("stable-diffusion")[0]
merge_models( merge_models(
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
mergeReq.ratio, mergeReq.ratio,
os.path.join( os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)),
app.MODELS_DIR,
"stable-diffusion",
filename_regex.sub("_", mergeReq.out_path),
),
mergeReq.use_fp16, mergeReq.use_fp16,
) )
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)

View File

@ -4,6 +4,7 @@ Notes:
Use weak_thread_data to store all other data using weak keys. Use weak_thread_data to store all other data using weak keys.
This will allow for garbage collection after the thread dies. This will allow for garbage collection after the thread dies.
""" """
import json import json
import traceback import traceback
@ -19,7 +20,6 @@ import torch
from easydiffusion import device_manager from easydiffusion import device_manager
from easydiffusion.tasks import Task from easydiffusion.tasks import Task
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit.utils import gc
THREAD_NAME_PREFIX = "" THREAD_NAME_PREFIX = ""
ERR_LOCK_FAILED = " failed to acquire lock within timeout." ERR_LOCK_FAILED = " failed to acquire lock within timeout."
@ -233,6 +233,8 @@ def thread_render(device):
global current_state, current_state_error global current_state, current_state_error
from easydiffusion import model_manager, runtime from easydiffusion import model_manager, runtime
from easydiffusion.backend_manager import backend
from requests import ConnectionError
try: try:
runtime.init(device) runtime.init(device)
@ -244,8 +246,17 @@ def thread_render(device):
} }
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
model_manager.load_default_models(runtime.context)
while True:
try:
if backend.ping(timeout=1):
break
time.sleep(1)
except (TimeoutError, ConnectionError):
time.sleep(1)
model_manager.load_default_models(runtime.context)
current_state = ServerStates.Online current_state = ServerStates.Online
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
@ -291,7 +302,6 @@ def thread_render(device):
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
log.error(traceback.format_exc()) log.error(traceback.format_exc())
finally: finally:
gc(runtime.context)
task.lock.release() task.lock.release()
keep_task_alive(task) keep_task_alive(task)
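The render thread now talks to a pluggable backend object instead of calling sdkit directly. Pieced together from the calls made across this PR, the interface looks roughly like this; the method names appear in the diff, but the signatures are abbreviated guesses, not the actual class:

```python
from typing import Any, List, Protocol

class BackendProtocol(Protocol):
    """Rough shape of easydiffusion.backend_manager.backend, inferred from usage in this PR."""
    def create_context(self) -> Any: ...
    def ping(self, timeout: float = 1) -> bool: ...
    def refresh_models(self) -> None: ...
    def load_model(self, context: Any, model_type: str, **kwargs) -> None: ...
    def unload_model(self, context: Any, model_type: str) -> None: ...
    def set_options(self, context: Any, **options) -> None: ...
    def generate_images(self, context: Any, callback=None, output_type: str = "base64", **req) -> List[str]: ...
    def filter_images(self, context: Any, images, filters, filter_params, input_type: str = "base64") -> List[str]: ...
    def stop_rendering(self, context: Any) -> None: ...
    def get_url(self) -> str: ...
    def list_controlnet_filters(self) -> List[str]: ...
```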

View File

@ -5,9 +5,7 @@ import time
from numpy import base_repr from numpy import base_repr
from sdkit.filter import apply_filters from sdkit.utils import img_to_base64_str, log, save_images, base64_str_to_img
from sdkit.models import load_model
from sdkit.utils import img_to_base64_str, get_image, log, save_images
from easydiffusion import model_manager, runtime from easydiffusion import model_manager, runtime
from easydiffusion.types import ( from easydiffusion.types import (
@ -19,6 +17,7 @@ from easydiffusion.types import (
TaskData, TaskData,
GenerateImageRequest, GenerateImageRequest,
) )
from easydiffusion.utils import filter_nsfw
from easydiffusion.utils.save_utils import format_folder_name from easydiffusion.utils.save_utils import format_folder_name
from .task import Task from .task import Task
@ -47,7 +46,9 @@ class FilterTask(Task):
# convert to multi-filter format, if necessary # convert to multi-filter format, if necessary
if isinstance(req.filter, str): if isinstance(req.filter, str):
if req.filter not in req.filter_params:
req.filter_params = {req.filter: req.filter_params} req.filter_params = {req.filter: req.filter_params}
req.filter = [req.filter] req.filter = [req.filter]
if not isinstance(req.image, list): if not isinstance(req.image, list):
@ -57,6 +58,7 @@ class FilterTask(Task):
"Runs the image filtering task on the assigned thread" "Runs the image filtering task on the assigned thread"
from easydiffusion import app from easydiffusion import app
from easydiffusion.backend_manager import backend
context = runtime.context context = runtime.context
@ -66,15 +68,24 @@ class FilterTask(Task):
print_task_info(self.request, self.models_data, self.output_format, self.save_data) print_task_info(self.request, self.models_data, self.output_format, self.save_data)
if isinstance(self.request.image, list): has_nsfw_filter = "nsfw_checker" in self.request.filter
images = [get_image(img) for img in self.request.image]
else:
images = get_image(self.request.image)
images = filter_images(context, images, self.request.filter, self.request.filter_params)
output_format = self.output_format output_format = self.output_format
backend.set_options(
context,
output_format=output_format.output_format,
output_quality=output_format.output_quality,
output_lossless=output_format.output_lossless,
)
images = backend.filter_images(
context, self.request.image, self.request.filter, self.request.filter_params, input_type="base64"
)
if has_nsfw_filter:
images = filter_nsfw(images)
if self.save_data.save_to_disk_path is not None: if self.save_data.save_to_disk_path is not None:
app_config = app.getConfig() app_config = app.getConfig()
folder_format = app_config.get("folder_format", "$id") folder_format = app_config.get("folder_format", "$id")
@ -85,8 +96,9 @@ class FilterTask(Task):
save_dir_path = os.path.join( save_dir_path = os.path.join(
self.save_data.save_to_disk_path, format_folder_name(folder_format, dummy_req, self.task_data) self.save_data.save_to_disk_path, format_folder_name(folder_format, dummy_req, self.task_data)
) )
images_pil = [base64_str_to_img(img) for img in images]
save_images( save_images(
images, images_pil,
save_dir_path, save_dir_path,
file_name=img_id, file_name=img_id,
output_format=output_format.output_format, output_format=output_format.output_format,
@ -94,13 +106,6 @@ class FilterTask(Task):
output_lossless=output_format.output_lossless, output_lossless=output_format.output_lossless,
) )
images = [
img_to_base64_str(
img, output_format.output_format, output_format.output_quality, output_format.output_lossless
)
for img in images
]
res = FilterImageResponse(self.request, self.models_data, images=images) res = FilterImageResponse(self.request, self.models_data, images=images)
res = res.json() res = res.json()
self.buffer_queue.put(json.dumps(res)) self.buffer_queue.put(json.dumps(res))
@ -110,46 +115,6 @@ class FilterTask(Task):
self.response = res self.response = res
def filter_images(context, images, filters, filter_params={}):
filters = filters if isinstance(filters, list) else [filters]
for filter_name in filters:
params = filter_params.get(filter_name, {})
previous_state = before_filter(context, filter_name, params)
try:
images = apply_filters(context, filter_name, images, **params)
finally:
after_filter(context, filter_name, params, previous_state)
return images
def before_filter(context, filter_name, filter_params):
if filter_name == "codeformer":
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
upscale_faces = filter_params.get("upscale_faces", False)
if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths.get("realesrgan")
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
return prev_realesrgan_path
def after_filter(context, filter_name, filter_params, previous_state):
if filter_name == "codeformer":
prev_realesrgan_path = previous_state
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
def print_task_info( def print_task_info(
req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData
): ):

View File

@ -2,26 +2,23 @@ import json
import pprint import pprint
import queue import queue
import time import time
from PIL import Image
from easydiffusion import model_manager, runtime from easydiffusion import model_manager, runtime
from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData, SaveToDiskData from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData, SaveToDiskData
from easydiffusion.types import Image as ResponseImage from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import GenerateImageResponse, RenderTaskData, UserInitiatedStop from easydiffusion.types import GenerateImageResponse, RenderTaskData
from easydiffusion.utils import get_printable_request, log, save_images_to_disk from easydiffusion.utils import get_printable_request, log, save_images_to_disk, filter_nsfw
from sdkit.generate import generate_images
from sdkit.utils import ( from sdkit.utils import (
diffusers_latent_samples_to_images,
gc,
img_to_base64_str, img_to_base64_str,
base64_str_to_img,
img_to_buffer, img_to_buffer,
latent_samples_to_images,
resize_img, resize_img,
get_image, get_image,
log, log,
) )
from .task import Task from .task import Task
from .filter_images import filter_images
class RenderTask(Task): class RenderTask(Task):
@ -51,15 +48,13 @@ class RenderTask(Task):
"Runs the image generation task on the assigned thread" "Runs the image generation task on the assigned thread"
from easydiffusion import task_manager, app from easydiffusion import task_manager, app
from easydiffusion.backend_manager import backend
context = runtime.context context = runtime.context
config = app.getConfig() config = app.getConfig()
if config.get("block_nsfw", False): # override if set on the server if config.get("block_nsfw", False): # override if set on the server
self.task_data.block_nsfw = True self.task_data.block_nsfw = True
if "nsfw_checker" not in self.task_data.filters:
self.task_data.filters.append("nsfw_checker")
self.models_data.model_paths["nsfw_checker"] = "nsfw_checker"
def step_callback(): def step_callback():
task_manager.keep_task_alive(self) task_manager.keep_task_alive(self)
@ -68,7 +63,7 @@ class RenderTask(Task):
if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance( if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance(
self.error, StopAsyncIteration self.error, StopAsyncIteration
): ):
context.stop_processing = True backend.stop_rendering(context)
if isinstance(task_manager.current_state_error, StopAsyncIteration): if isinstance(task_manager.current_state_error, StopAsyncIteration):
self.error = task_manager.current_state_error self.error = task_manager.current_state_error
task_manager.current_state_error = None task_manager.current_state_error = None
@ -78,11 +73,7 @@ class RenderTask(Task):
model_manager.resolve_model_paths(self.models_data) model_manager.resolve_model_paths(self.models_data)
models_to_force_reload = [] models_to_force_reload = []
if ( if runtime.set_vram_optimizations(context) or self.has_param_changed(context, "clip_skip"):
runtime.set_vram_optimizations(context)
or self.has_param_changed(context, "clip_skip")
or self.trt_needs_reload(context)
):
models_to_force_reload.append("stable-diffusion") models_to_force_reload.append("stable-diffusion")
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload) model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
@ -99,10 +90,11 @@ class RenderTask(Task):
self.buffer_queue, self.buffer_queue,
self.temp_images, self.temp_images,
step_callback, step_callback,
self,
) )
def has_param_changed(self, context, param_name): def has_param_changed(self, context, param_name):
if not context.test_diffusers: if not getattr(context, "test_diffusers", False):
return False return False
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
return True return True
@ -111,29 +103,6 @@ class RenderTask(Task):
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
return model["params"].get(param_name) != new_val return model["params"].get(param_name) != new_val
def trt_needs_reload(self, context):
if not context.test_diffusers:
return False
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
return True
model = context.models["stable-diffusion"]
# curr_convert_to_trt = model["params"].get("convert_to_tensorrt")
new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False)
pipe = model["default"]
is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr(
pipe.unet, "_allocate_trt_buffers_backup"
)
if new_convert_to_trt and not is_trt_loaded:
return True
curr_build_config = model["params"].get("trt_build_config")
new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {})
return new_convert_to_trt and curr_build_config != new_build_config
def make_images( def make_images(
context, context,
@ -145,12 +114,21 @@ def make_images(
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
task,
): ):
context.stop_processing = False
print_task_info(req, task_data, models_data, output_format, save_data) print_task_info(req, task_data, models_data, output_format, save_data)
images, seeds = make_images_internal( images, seeds = make_images_internal(
context, req, task_data, models_data, output_format, save_data, data_queue, task_temp_images, step_callback context,
req,
task_data,
models_data,
output_format,
save_data,
data_queue,
task_temp_images,
step_callback,
task,
) )
res = GenerateImageResponse( res = GenerateImageResponse(
@ -170,7 +148,9 @@ def print_task_info(
output_format: OutputFormatData, output_format: OutputFormatData,
save_data: SaveToDiskData, save_data: SaveToDiskData,
): ):
req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[") req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace(
"[", "\["
)
task_str = pprint.pformat(task_data.dict()).replace("[", "\[") task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
@ -178,7 +158,7 @@ def print_task_info(
log.info(f"request: {req_str}") log.info(f"request: {req_str}")
log.info(f"task data: {task_str}") log.info(f"task data: {task_str}")
# log.info(f"models data: {models_data}") log.info(f"models data: {models_data}")
log.info(f"output format: {output_format}") log.info(f"output format: {output_format}")
log.info(f"save data: {save_data}") log.info(f"save data: {save_data}")
@ -193,26 +173,41 @@ def make_images_internal(
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
task,
): ):
images, user_stopped = generate_images_internal( from easydiffusion.backend_manager import backend
# prep the nsfw_filter
if task_data.block_nsfw:
filter_nsfw([Image.new("RGB", (1, 1))]) # hack - ensures that the model is available
images = generate_images_internal(
context, context,
req, req,
task_data, task_data,
models_data, models_data,
output_format,
data_queue, data_queue,
task_temp_images, task_temp_images,
step_callback, step_callback,
task_data.stream_image_progress, task_data.stream_image_progress,
task_data.stream_image_progress_interval, task_data.stream_image_progress_interval,
) )
user_stopped = isinstance(task.error, StopAsyncIteration)
gc(context)
filters, filter_params = task_data.filters, task_data.filter_params filters, filter_params = task_data.filters, task_data.filter_params
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images if len(filters) > 0 and not user_stopped:
filtered_images = backend.filter_images(context, images, filters, filter_params, input_type="base64")
else:
filtered_images = images
if task_data.block_nsfw:
filtered_images = filter_nsfw(filtered_images)
if save_data.save_to_disk_path is not None: if save_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data) images_pil = [base64_str_to_img(img) for img in images]
filtered_images_pil = [base64_str_to_img(img) for img in filtered_images]
save_images_to_disk(images_pil, filtered_images_pil, req, task_data, models_data, output_format, save_data)
seeds = [*range(req.seed, req.seed + len(images))] seeds = [*range(req.seed, req.seed + len(images))]
if task_data.show_only_filtered_image or filtered_images is images: if task_data.show_only_filtered_image or filtered_images is images:
@ -226,97 +221,43 @@ def generate_images_internal(
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: RenderTaskData, task_data: RenderTaskData,
models_data: ModelsData, models_data: ModelsData,
output_format: OutputFormatData,
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
stream_image_progress: bool, stream_image_progress: bool,
stream_image_progress_interval: int, stream_image_progress_interval: int,
): ):
context.temp_images.clear() from easydiffusion.backend_manager import backend
callback = make_step_callback( callback = make_step_callback(context, req, task_data, data_queue, task_temp_images, step_callback)
context,
req,
task_data,
data_queue,
task_temp_images,
step_callback,
stream_image_progress,
stream_image_progress_interval,
)
try:
if req.init_image is not None and not context.test_diffusers:
req.sampler_name = "ddim"
req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # round down to a multiple of 8 req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # round down to a multiple of 8
if req.control_image and task_data.control_filter_to_apply: if req.control_image and task_data.control_filter_to_apply:
req.control_image = get_image(req.control_image) req.controlnet_filter = task_data.control_filter_to_apply
req.control_image = resize_img(req.control_image.convert("RGB"), req.width, req.height, clamp_to_8=True)
req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0]
if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0: if req.init_image is not None and int(req.num_inference_steps * req.prompt_strength) == 0:
req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1 req.prompt_strength = 1 / req.num_inference_steps if req.num_inference_steps > 0 else 1
if context.test_diffusers: backend.set_options(
pipe = context.models["stable-diffusion"]["default"] context,
if hasattr(pipe.unet, "_allocate_trt_buffers_backup"): output_format=output_format.output_format,
setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup) output_quality=output_format.output_quality,
delattr(pipe.unet, "_allocate_trt_buffers_backup") output_lossless=output_format.output_lossless,
vae_tiling=task_data.enable_vae_tiling,
stream_image_progress=stream_image_progress,
stream_image_progress_interval=stream_image_progress_interval,
clip_skip=2 if task_data.clip_skip else 1,
)
if hasattr(pipe.unet, "_allocate_trt_buffers"): images = backend.generate_images(context, callback=callback, output_type="base64", **req.dict())
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
if convert_to_trt:
pipe.unet.forward = pipe.unet._trt_forward
# pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward
log.info(f"Setting unet.forward to TensorRT")
else:
log.info(f"Not using TensorRT for unet.forward")
pipe.unet.forward = pipe.unet._non_trt_forward
# pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward
setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers)
delattr(pipe.unet, "_allocate_trt_buffers")
if task_data.enable_vae_tiling: return images
if hasattr(pipe, "enable_vae_tiling"):
pipe.enable_vae_tiling()
else:
if hasattr(pipe, "disable_vae_tiling"):
pipe.disable_vae_tiling()
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if context.partial_x_samples is not None:
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
else:
images = latent_samples_to_images(context, context.partial_x_samples)
finally:
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
if not context.test_diffusers:
del context.partial_x_samples
context.partial_x_samples = None
return images, user_stopped
def construct_response(images: list, seeds: list, output_format: OutputFormatData): def construct_response(images: list, seeds: list, output_format: OutputFormatData):
return [ return [ResponseImage(data=img, seed=seed) for img, seed in zip(images, seeds)]
ResponseImage(
data=img_to_base64_str(
img,
output_format.output_format,
output_format.output_quality,
output_format.output_lossless,
),
seed=seed,
)
for img, seed in zip(images, seeds)
]
def make_step_callback( def make_step_callback(
@ -326,53 +267,44 @@ def make_step_callback(
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
): ):
from easydiffusion.backend_manager import backend
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1 last_callback_time = -1
def update_temp_img(x_samples, task_temp_images: list): def update_temp_img(images, task_temp_images: list):
partial_images = [] partial_images = []
if context.test_diffusers: if images is None:
images = diffusers_latent_samples_to_images(context, x_samples) return []
else:
images = latent_samples_to_images(context, x_samples)
if task_data.block_nsfw: if task_data.block_nsfw:
images = filter_images(context, images, "nsfw_checker") images = filter_nsfw(images, print_log=False)
for i, img in enumerate(images): for i, img in enumerate(images):
img = img.convert("RGB")
img = resize_img(img, req.width, req.height)
buf = img_to_buffer(img, output_format="JPEG") buf = img_to_buffer(img, output_format="JPEG")
context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf task_temp_images[i] = buf
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"}) partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
del images del images
return partial_images return partial_images
def on_image_step(x_samples, i, *args): def on_image_step(images, i, *args):
nonlocal last_callback_time nonlocal last_callback_time
if context.test_diffusers:
context.partial_x_samples = (x_samples, args[0])
else:
context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time() last_callback_time = time.time()
progress = {"step": i, "step_time": step_time, "total_steps": n_steps} progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: if images is not None:
progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images) progress["output"] = update_temp_img(images, task_temp_images)
data_queue.put(json.dumps(progress)) data_queue.put(json.dumps(progress))
step_callback() step_callback()
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step return on_image_step
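Each streamed progress message assembled above serializes to JSON along these lines (values illustrative):

```python
import json

# Illustrative payload placed on data_queue by on_image_step; "output" is
# present only when the backend streamed preview images for this step.
progress = {
    "step": 12,
    "step_time": 0.42,  # seconds since the previous callback
    "total_steps": 25,
    "output": [{"path": "/image/tmp/<request_id>/0"}],
}
print(json.dumps(progress))
```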

View File

@ -14,16 +14,19 @@ class GenerateImageRequest(BaseModel):
num_outputs: int = 1 num_outputs: int = 1
num_inference_steps: int = 50 num_inference_steps: int = 50
guidance_scale: float = 7.5 guidance_scale: float = 7.5
distilled_guidance_scale: float = 3.5
init_image: Any = None init_image: Any = None
init_image_mask: Any = None init_image_mask: Any = None
control_image: Any = None control_image: Any = None
control_alpha: Union[float, List[float]] = None control_alpha: Union[float, List[float]] = None
controlnet_filter: str = None
prompt_strength: float = 0.8 prompt_strength: float = 0.8
preserve_init_image_color_profile = False preserve_init_image_color_profile: bool = False
strict_mask_border = False strict_mask_border: bool = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
scheduler_name: str = None
hypernetwork_strength: float = 0 hypernetwork_strength: float = 0
lora_alpha: Union[float, List[float]] = 0 lora_alpha: Union[float, List[float]] = 0
tiling: str = None # None, "x", "y", "xy" tiling: str = None # None, "x", "y", "xy"
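A request exercising the new v3.5 fields might look like this; `prompt` and the exact scheduler id are assumptions, not taken from this hunk:

```python
# Illustrative request touching the new fields; `prompt` and the scheduler
# id "karras" are assumptions, not confirmed by this diff.
req = GenerateImageRequest(
    prompt="a photograph of an astronaut riding a horse",
    sampler_name="euler_a",
    scheduler_name="karras",       # new: the schedule is now separate from the sampler
    guidance_scale=7.5,
    distilled_guidance_scale=3.5,  # new: used by distilled models such as Flux
    controlnet_filter="canny",     # new: filter applied by the backend itself
)
```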
@ -100,7 +103,7 @@ class MergeRequest(BaseModel):
model1: str = None model1: str = None
ratio: float = None ratio: float = None
out_path: str = "mix" out_path: str = "mix"
use_fp16 = True use_fp16: bool = True
class Image: class Image:
@ -213,22 +216,19 @@ def convert_legacy_render_req_to_new(old_req: dict):
model_paths["controlnet"] = old_req.get("use_controlnet_model") model_paths["controlnet"] = old_req.get("use_controlnet_model")
model_paths["embeddings"] = old_req.get("use_embeddings_model") model_paths["embeddings"] = old_req.get("use_embeddings_model")
model_paths["gfpgan"] = old_req.get("use_face_correction", "") ## ensure that the model name is in the model path
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None for model_name in ("gfpgan", "codeformer"):
model_paths[model_name] = old_req.get("use_face_correction", "")
model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None
model_paths["codeformer"] = old_req.get("use_face_correction", "") for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None model_paths[model_name] = old_req.get("use_upscale", "")
model_paths[model_name] = model_paths[model_name] if model_name in model_paths[model_name].lower() else None
model_paths["realesrgan"] = old_req.get("use_upscale", "")
model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None
model_paths["latent_upscaler"] = old_req.get("use_upscale", "")
model_paths["latent_upscaler"] = (
model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
)
if "control_filter_to_apply" in old_req: if "control_filter_to_apply" in old_req:
filter_model = old_req["control_filter_to_apply"] filter_model = old_req["control_filter_to_apply"]
model_paths[filter_model] = filter_model model_paths[filter_model] = filter_model
old_req["control_filter_to_apply"] = convert_legacy_controlnet_filter_name(old_req["control_filter_to_apply"])
if old_req.get("block_nsfw"): if old_req.get("block_nsfw"):
model_paths["nsfw_checker"] = "nsfw_checker" model_paths["nsfw_checker"] = "nsfw_checker"
@ -244,8 +244,12 @@ def convert_legacy_render_req_to_new(old_req: dict):
} }
# move the filter params # move the filter params
if model_paths["realesrgan"]: for model_name in ("realesrgan", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))} if model_paths[model_name]:
filter_params[model_name] = {
"upscaler": model_paths[model_name],
"scale": int(old_req.get("upscale_amount", 4)),
}
if model_paths["latent_upscaler"]: if model_paths["latent_upscaler"]:
filter_params["latent_upscaler"] = { filter_params["latent_upscaler"] = {
"prompt": old_req["prompt"], "prompt": old_req["prompt"],
@ -264,14 +268,31 @@ def convert_legacy_render_req_to_new(old_req: dict):
if old_req.get("block_nsfw"): if old_req.get("block_nsfw"):
filters.append("nsfw_checker") filters.append("nsfw_checker")
if model_paths["codeformer"]: for model_name in ("gfpgan", "codeformer"):
filters.append("codeformer") if model_paths[model_name]:
elif model_paths["gfpgan"]: filters.append(model_name)
filters.append("gfpgan") break
if model_paths["realesrgan"]: for model_name in ("realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"):
filters.append("realesrgan") if model_paths[model_name]:
elif model_paths["latent_upscaler"]: filters.append(model_name)
filters.append("latent_upscaler") break
return new_req return new_req
def convert_legacy_controlnet_filter_name(filter):
from easydiffusion.backend_manager import backend
if filter is None:
return None
controlnet_filter_names = backend.list_controlnet_filters()
def apply(f):
return f"controlnet_{f}" if f in controlnet_filter_names else f
if isinstance(filter, list):
return [apply(f) for f in filter]
return apply(filter)
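So a legacy filter id gains the `controlnet_` prefix only when the backend recognizes it; for example (assuming the backend lists `canny` among its filters):

```python
# Illustrative, assuming backend.list_controlnet_filters() includes "canny":
convert_legacy_controlnet_filter_name("canny")             # -> "controlnet_canny"
convert_legacy_controlnet_filter_name(["canny", "other"])  # -> ["controlnet_canny", "other"]
convert_legacy_controlnet_filter_name(None)                # -> None
```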

View File

@ -7,6 +7,8 @@ from .save_utils import (
save_images_to_disk, save_images_to_disk,
get_printable_request, get_printable_request,
) )
from .nsfw_checker import filter_nsfw
def sha256sum(filename): def sha256sum(filename):
sha256 = hashlib.sha256() sha256 = hashlib.sha256()
@ -18,4 +20,3 @@ def sha256sum(filename):
sha256.update(data) sha256.update(data)
return sha256.hexdigest() return sha256.hexdigest()

View File

@ -0,0 +1,80 @@
# possibly move this to sdkit in the future
import os
# mirror of https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/model_quantized.onnx
NSFW_MODEL_URL = (
"https://github.com/easydiffusion/sdkit-test-data/releases/download/assets/vit-base-nsfw-detector-quantized.onnx"
)
MODEL_HASH_QUICK = "220123559305b1b07b7a0894c3471e34dccd090d71cdf337dd8012f9e40d6c28"
nsfw_check_model = None
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
global nsfw_check_model
from easydiffusion.model_manager import get_model_dirs
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
import onnxruntime as ort
from PIL import ImageFilter
import numpy as np
if nsfw_check_model is None:
model_dir = get_model_dirs("nsfw-checker")[0]
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
os.makedirs(model_dir, exist_ok=True)
if not os.path.exists(model_path) or hash_file_quick(model_path) != MODEL_HASH_QUICK:
download_file(NSFW_MODEL_URL, model_path)
nsfw_check_model = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Preprocess the input image
def preprocess_image(img):
img = img.convert("RGB")
# config based on https://huggingface.co/AdamCodd/vit-base-nsfw-detector/blob/main/onnx/preprocessor_config.json
# Resize the image
img = img.resize((384, 384))
# Normalize the image
img = np.array(img) / 255.0 # Scale pixel values to [0, 1]
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
img = (img - mean) / std
# Transpose to match input shape (batch_size, channels, height, width)
img = np.transpose(img, (2, 0, 1)).astype(np.float32)
# Add batch dimension
img = np.expand_dims(img, axis=0)
return img
# Run inference
input_name = nsfw_check_model.get_inputs()[0].name
output_name = nsfw_check_model.get_outputs()[0].name
if print_log:
log.info("Running NSFW checker (onnx)")
results = []
for img in images:
is_base64 = isinstance(img, str)
input_img = base64_str_to_img(img) if is_base64 else img
result = nsfw_check_model.run([output_name], {input_name: preprocess_image(input_img)})
is_nsfw = [np.argmax(arr) == 1 for arr in result][0]
if is_nsfw:
output_img = input_img.filter(ImageFilter.GaussianBlur(blur_radius))
output_img = img_to_base64_str(output_img) if is_base64 else output_img
else:
output_img = img
results.append(output_img)
return results
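Inputs can be PIL images or base64 strings; each output keeps its input's type, and flagged images come back Gaussian-blurred. A minimal usage sketch (note that the first call downloads the ONNX model):

```python
from PIL import Image

# Minimal sketch: a plain grey image should come back unmodified (a PIL
# input yields a PIL output; base64 inputs yield base64 outputs).
imgs = filter_nsfw([Image.new("RGB", (64, 64), "gray")], blur_radius=50)
assert len(imgs) == 1
```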

View File

@ -34,10 +34,12 @@ TASK_TEXT_MAPPING = {
"control_alpha": "ControlNet Strength", "control_alpha": "ControlNet Strength",
"use_vae_model": "VAE model", "use_vae_model": "VAE model",
"sampler_name": "Sampler", "sampler_name": "Sampler",
"scheduler_name": "Scheduler",
"width": "Width", "width": "Width",
"height": "Height", "height": "Height",
"num_inference_steps": "Steps", "num_inference_steps": "Steps",
"guidance_scale": "Guidance Scale", "guidance_scale": "Guidance Scale",
"distilled_guidance_scale": "Distilled Guidance",
"prompt_strength": "Prompt Strength", "prompt_strength": "Prompt Strength",
"use_lora_model": "LoRA model", "use_lora_model": "LoRA model",
"lora_alpha": "LoRA Strength", "lora_alpha": "LoRA Strength",
@ -247,7 +249,7 @@ def get_printable_request(
task_data_metadata.update(save_data.dict()) task_data_metadata.update(save_data.dict())
app_config = app.getConfig() app_config = app.getConfig()
using_diffusers = app_config.get("use_v3_engine", True) using_diffusers = app_config.get("backend", "ed_diffusers") in ("ed_diffusers", "webui")
# Save the metadata in the order defined in TASK_TEXT_MAPPING # Save the metadata in the order defined in TASK_TEXT_MAPPING
metadata = {} metadata = {}

View File

@ -35,7 +35,13 @@
<h1> <h1>
<img id="logo_img" src="/media/images/icon-512x512.png" > <img id="logo_img" src="/media/images/icon-512x512.png" >
Easy Diffusion Easy Diffusion
<small><span id="version">v3.0.9</span> <span id="updateBranchLabel"></span></small> <small>
<span id="version">
<span class="gated-feature" data-feature-keys="backend_ed_classic backend_ed_diffusers">v3.0.10</span>
<span class="gated-feature" data-feature-keys="backend_webui">v3.5.0</span>
</span> <span id="updateBranchLabel"></span>
<div id="engine-logo" class="gated-feature" data-feature-keys="backend_webui">(Powered by <a id="backend-url" href="https://github.com/lllyasviel/stable-diffusion-webui-forge" target="_blank">Stable Diffusion WebUI Forge</a>)</div>
</small>
</h1> </h1>
</div> </div>
<div id="server-status"> <div id="server-status">
@ -73,7 +79,7 @@
</div> </div>
<div id="prompt-toolbar-right" class="toolbar-right"> <div id="prompt-toolbar-right" class="toolbar-right">
<button id="image-modifier-dropdown" class="tertiaryButton smallButton">+ Image Modifiers</button> <button id="image-modifier-dropdown" class="tertiaryButton smallButton">+ Image Modifiers</button>
<button id="embeddings-button" class="tertiaryButton smallButton displayNone">+ Embedding</button> <button id="embeddings-button" class="tertiaryButton smallButton gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">+ Embedding</button>
</div> </div>
</div> </div>
<textarea id="prompt" class="col-free">a photograph of an astronaut riding a horse</textarea> <textarea id="prompt" class="col-free">a photograph of an astronaut riding a horse</textarea>
@ -83,7 +89,7 @@
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Writing-prompts#negative-prompts" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top">Click to learn more about Negative Prompts</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/Writing-prompts#negative-prompts" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top">Click to learn more about Negative Prompts</span></i></a>
<small>(optional)</small> <small>(optional)</small>
</label> </label>
<button id="negative-embeddings-button" class="tertiaryButton smallButton displayNone">+ Negative Embedding</button> <button id="negative-embeddings-button" class="tertiaryButton smallButton gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">+ Negative Embedding</button>
<div class="collapsible-content"> <div class="collapsible-content">
<textarea id="negative_prompt" name="negative_prompt" placeholder="list the things to remove from the image (e.g. fog, green)"></textarea> <textarea id="negative_prompt" name="negative_prompt" placeholder="list the things to remove from the image (e.g. fog, green)"></textarea>
</div> </div>
@ -174,14 +180,14 @@
<!-- <label><small>Takes up to 20 mins the first time</small></label> --> <!-- <label><small>Takes up to 20 mins the first time</small></label> -->
</td> </td>
</tr> </tr>
<tr class="pl-5 displayNone" id="clip_skip_config"> <tr class="pl-5 gated-feature" id="clip_skip_config" data-feature-keys="backend_ed_diffusers backend_webui">
<td><label for="clip_skip">Clip Skip:</label></td> <td><label for="clip_skip">Clip Skip:</label></td>
<td class="diffusers-restart-needed"> <td class="diffusers-restart-needed">
<input id="clip_skip" name="clip_skip" type="checkbox"> <input id="clip_skip" name="clip_skip" type="checkbox">
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Clip-Skip" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about Clip Skip</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/Clip-Skip" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about Clip Skip</span></i></a>
</td> </td>
</tr> </tr>
<tr id="controlnet_model_container" class="pl-5"> <tr id="controlnet_model_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">
<td><label for="controlnet_model">ControlNet Image:</label></td> <td><label for="controlnet_model">ControlNet Image:</label></td>
<td class="diffusers-restart-needed"> <td class="diffusers-restart-needed">
<div id="control_image_wrapper" class="preview_image_wrapper"> <div id="control_image_wrapper" class="preview_image_wrapper">
@ -201,40 +207,92 @@
<option value="openpose_faceonly">OpenPose face-only</option> <option value="openpose_faceonly">OpenPose face-only</option>
<option value="openpose_hand">OpenPose hand</option> <option value="openpose_hand">OpenPose hand</option>
<option value="openpose_full">OpenPose full</option> <option value="openpose_full">OpenPose full</option>
<option value="animal_openpose" class="gated-feature" data-feature-keys="backend_webui">animal_openpose</option>
<option value="densepose_parula (black bg & blue torso)" class="gated-feature" data-feature-keys="backend_webui">densepose_parula (black bg & blue torso)</option>
<option value="densepose (pruple bg & purple torso)" class="gated-feature" data-feature-keys="backend_webui">densepose (pruple bg & purple torso)</option>
<option value="dw_openpose_full" class="gated-feature" data-feature-keys="backend_webui">dw_openpose_full</option>
<option value="mediapipe_face" class="gated-feature" data-feature-keys="backend_webui">mediapipe_face</option>
<option value="instant_id_face_keypoints" class="gated-feature" data-feature-keys="backend_webui">instant_id_face_keypoints</option>
<option value="InsightFace+CLIP-H (IPAdapter)" class="gated-feature" data-feature-keys="backend_webui">InsightFace+CLIP-H (IPAdapter)</option>
<option value="InsightFace (InstantID)" class="gated-feature" data-feature-keys="backend_webui">InsightFace (InstantID)</option>
</optgroup> </optgroup>
<optgroup label="Outline"> <optgroup label="Outline">
<option value="canny">Canny (*)</option> <option value="canny">Canny (*)</option>
<option value="mlsd">Straight lines</option> <option value="mlsd">Straight lines</option>
<option value="scribble_hed">Scribble hed (*)</option> <option value="scribble_hed">Scribble hed (*)</option>
<option value="scribble_hedsafe">Scribble hedsafe</option> <option value="scribble_hedsafe" class="gated-feature" data-feature-keys="backend_diffusers">Scribble hedsafe</option>
<option value="scribble_pidinet">Scribble pidinet</option> <option value="scribble_pidinet">Scribble pidinet</option>
<option value="scribble_pidsafe">Scribble pidsafe</option> <option value="scribble_pidsafe" class="gated-feature" data-feature-keys="backend_diffusers">Scribble pidsafe</option>
<option value="scribble_xdog" class="gated-feature" data-feature-keys="backend_webui">scribble_xdog</option>
<option value="softedge_hed">Softedge hed</option> <option value="softedge_hed">Softedge hed</option>
<option value="softedge_hedsafe">Softedge hedsafe</option> <option value="softedge_hedsafe">Softedge hedsafe</option>
<option value="softedge_pidinet">Softedge pidinet</option> <option value="softedge_pidinet">Softedge pidinet</option>
<option value="softedge_pidsafe">Softedge pidsafe</option> <option value="softedge_pidsafe">Softedge pidsafe</option>
<option value="softedge_teed" class="gated-feature" data-feature-keys="backend_webui">softedge_teed</option>
</optgroup> </optgroup>
<optgroup label="Depth"> <optgroup label="Depth">
<option value="normal_bae">Normal bae (*)</option> <option value="normal_bae">Normal bae (*)</option>
<option value="depth_midas">Depth midas</option> <option value="depth_midas">Depth midas</option>
<option value="normal_midas" class="gated-feature" data-feature-keys="backend_webui">normal_midas</option>
<option value="depth_zoe">Depth zoe</option> <option value="depth_zoe">Depth zoe</option>
<option value="depth_leres">Depth leres</option> <option value="depth_leres">Depth leres</option>
<option value="depth_leres++">Depth leres++</option> <option value="depth_leres++">Depth leres++</option>
<option value="depth_anything_v2" class="gated-feature" data-feature-keys="backend_webui">depth_anything_v2</option>
<option value="depth_anything" class="gated-feature" data-feature-keys="backend_webui">depth_anything</option>
<option value="depth_hand_refiner" class="gated-feature" data-feature-keys="backend_webui">depth_hand_refiner</option>
<option value="depth_marigold" class="gated-feature" data-feature-keys="backend_webui">depth_marigold</option>
</optgroup> </optgroup>
<optgroup label="Line art"> <optgroup label="Line art">
<option value="lineart_coarse">Lineart coarse</option> <option value="lineart_coarse">Lineart coarse</option>
<option value="lineart_realistic">Lineart realistic</option> <option value="lineart_realistic">Lineart realistic</option>
<option value="lineart_anime">Lineart anime</option> <option value="lineart_anime">Lineart anime</option>
<option value="lineart_standard (from white bg & black line)" class="gated-feature" data-feature-keys="backend_webui">lineart_standard (from white bg & black line)</option>
<option value="lineart_anime_denoise" class="gated-feature" data-feature-keys="backend_webui">lineart_anime_denoise</option>
</optgroup>
<optgroup label="Reference" class="gated-feature" data-feature-keys="backend_webui">
<option value="reference_adain">reference_adain</option>
<option value="reference_only">reference_only</option>
<option value="reference_adain+attn">reference_adain+attn</option>
</optgroup>
<optgroup label="Tile" class="gated-feature" data-feature-keys="backend_webui">
<option value="tile_colorfix">tile_colorfix</option>
<option value="tile_resample">tile_resample</option>
<option value="tile_colorfix+sharp">tile_colorfix+sharp</option>
</optgroup>
<optgroup label="CLIP (IPAdapter)" class="gated-feature" data-feature-keys="backend_webui">
<option value="CLIP-ViT-H (IPAdapter)">CLIP-ViT-H (IPAdapter)</option>
<option value="CLIP-G (Revision)">CLIP-G (Revision)</option>
<option value="CLIP-G (Revision ignore prompt)">CLIP-G (Revision ignore prompt)</option>
<option value="CLIP-ViT-bigG (IPAdapter)">CLIP-ViT-bigG (IPAdapter)</option>
<option value="InsightFace+CLIP-H (IPAdapter)">InsightFace+CLIP-H (IPAdapter)</option>
</optgroup>
<optgroup label="Inpaint" class="gated-feature" data-feature-keys="backend_webui">
<option value="inpaint_only">inpaint_only</option>
<option value="inpaint_only+lama">inpaint_only+lama</option>
<option value="inpaint_global_harmonious">inpaint_global_harmonious</option>
</optgroup>
<optgroup label="Segment" class="gated-feature" data-feature-keys="backend_webui">
<option value="seg_ufade20k">seg_ufade20k</option>
<option value="seg_ofade20k">seg_ofade20k</option>
<option value="seg_anime_face">seg_anime_face</option>
<option value="seg_ofcoco">seg_ofcoco</option>
</optgroup> </optgroup>
<optgroup label="Misc"> <optgroup label="Misc">
<option value="shuffle">Shuffle</option> <option value="shuffle">Shuffle</option>
<option value="segment">Segment</option> <option value="segment" class="gated-feature" data-feature-keys="backend_diffusers">Segment</option>
<option value="invert (from white bg & black line)" class="gated-feature" data-feature-keys="backend_webui">invert (from white bg & black line)</option>
<option value="threshold" class="gated-feature" data-feature-keys="backend_webui">threshold</option>
<option value="t2ia_sketch_pidi" class="gated-feature" data-feature-keys="backend_webui">t2ia_sketch_pidi</option>
<option value="t2ia_color_grid" class="gated-feature" data-feature-keys="backend_webui">t2ia_color_grid</option>
<option value="recolor_intensity" class="gated-feature" data-feature-keys="backend_webui">recolor_intensity</option>
<option value="recolor_luminance" class="gated-feature" data-feature-keys="backend_webui">recolor_luminance</option>
<option value="blur_gaussian" class="gated-feature" data-feature-keys="backend_webui">blur_gaussian</option>
</optgroup> </optgroup>
</select> </select>
<br/> <br/>
<label for="controlnet_model"><small>Model:</small></label> <input id="controlnet_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" /> <label for="controlnet_model"><small>Model:</small></label> <input id="controlnet_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<br/> <!-- <br/>
<label><small>Will download the necessary models, the first time.</small></label> <label><small>Will download the necessary models, the first time.</small></label> -->
<br/> <br/>
<label for="controlnet_alpha_slider"><small>Strength:</small></label> <input id="controlnet_alpha_slider" name="controlnet_alpha_slider" class="editor-slider" value="10" type="range" min="0" max="10"> <input id="controlnet_alpha" name="controlnet_alpha" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"> <label for="controlnet_alpha_slider"><small>Strength:</small></label> <input id="controlnet_alpha_slider" name="controlnet_alpha_slider" class="editor-slider" value="10" type="range" min="0" max="10"> <input id="controlnet_alpha" name="controlnet_alpha" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal">
</div> </div>
@ -248,27 +306,59 @@
<select id="sampler_name" name="sampler_name"> <select id="sampler_name" name="sampler_name">
<option value="plms">PLMS</option> <option value="plms">PLMS</option>
<option value="ddim">DDIM</option> <option value="ddim">DDIM</option>
<option value="ddim_cfgpp" class="gated-feature" data-feature-keys="backend_webui">DDIM CFG++</option>
<option value="heun">Heun</option> <option value="heun">Heun</option>
<option value="euler">Euler</option> <option value="euler">Euler</option>
<option value="euler_a" selected>Euler Ancestral</option> <option value="euler_a" selected>Euler Ancestral</option>
<option value="dpm2">DPM2</option> <option value="dpm2">DPM2</option>
<option value="dpm2_a">DPM2 Ancestral</option> <option value="dpm2_a">DPM2 Ancestral</option>
<option value="dpm_fast" class="gated-feature" data-feature-keys="backend_webui">DPM Fast</option>
<option value="dpm_adaptive" class="gated-feature" data-feature-keys="backend_ed_classic backend_webui">DPM Adaptive</option>
<option value="lms">LMS</option> <option value="lms">LMS</option>
<option value="dpm_solver_stability">DPM Solver (Stability AI)</option> <option value="dpm_solver_stability" class="gated-feature" data-feature-keys="backend_ed_classic backend_ed_diffusers">DPM Solver (Stability AI)</option>
<option value="dpmpp_2s_a">DPM++ 2s Ancestral (Karras)</option> <option value="dpmpp_2s_a">DPM++ 2s Ancestral (Karras)</option>
<option value="dpmpp_2m">DPM++ 2m (Karras)</option> <option value="dpmpp_2m">DPM++ 2m (Karras)</option>
<option value="dpmpp_2m_sde" class="diffusers-only">DPM++ 2m SDE (Karras)</option> <option value="dpmpp_2m_sde" class="gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">DPM++ 2m SDE</option>
<option value="dpmpp_2m_sde_heun" class="gated-feature" data-feature-keys="backend_webui">DPM++ 2m SDE Heun</option>
<option value="dpmpp_3m_sde" class="gated-feature" data-feature-keys="backend_webui">DPM++ 3M SDE</option>
<option value="dpmpp_sde">DPM++ SDE (Karras)</option> <option value="dpmpp_sde">DPM++ SDE (Karras)</option>
<option value="dpm_adaptive" class="k_diffusion-only">DPM Adaptive (Karras)</option> <option value="restart" class="gated-feature" data-feature-keys="backend_webui">Restart</option>
<option value="ddpm" class="diffusers-only">DDPM</option> <option value="heun_pp2" class="gated-feature" data-feature-keys="backend_webui">Heun PP2</option>
<option value="deis" class="diffusers-only">DEIS</option> <option value="ipndm" class="gated-feature" data-feature-keys="backend_webui">IPNDM</option>
<option value="unipc_snr" class="k_diffusion-only">UniPC SNR</option> <option value="ipndm_v" class="gated-feature" data-feature-keys="backend_webui">IPNDM_V</option>
<option value="unipc_tu">UniPC TU</option> <option value="ddpm" class="gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">DDPM</option>
<option value="unipc_snr_2" class="k_diffusion-only">UniPC SNR 2</option> <option value="deis" class="gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">DEIS</option>
<option value="unipc_tu_2" class="k_diffusion-only">UniPC TU 2</option> <option value="lcm" class="gated-feature" data-feature-keys="backend_webui">LCM</option>
<option value="unipc_tq" class="k_diffusion-only">UniPC TQ</option> <option value="forge_flux_realistic" class="gated-feature" data-feature-keys="backend_webui">[Forge] Flux Realistic</option>
<option value="forge_flux_realistic_slow" class="gated-feature" data-feature-keys="backend_webui">[Forge] Flux Realistic (Slow)</option>
<option value="unipc_snr" class="gated-feature" data-feature-keys="backend_ed_classic">UniPC SNR</option>
<option value="unipc_tu" class="gated-feature" data-feature-keys="backend_ed_classic backend_ed_diffusers">UniPC TU</option>
<option value="unipc_snr_2" class="gated-feature" data-feature-keys="backend_ed_classic">UniPC SNR 2</option>
<option value="unipc_tu_2" class="gated-feature" data-feature-keys="backend_ed_classic">UniPC TU 2</option>
<option value="unipc_tq" class="gated-feature" data-feature-keys="backend_ed_classic">UniPC TQ</option>
</select>
<a href="https://github.com/easydiffusion/easydiffusion/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
</td></tr>
<tr class="pl-5 warning-label displayNone" id="fluxSamplerWarning"><td></td><td>Please avoid 'Euler Ancestral' with Flux!</td></tr>
<tr id="schedulerSelection" class="pl-5 gated-feature" data-feature-keys="backend_webui"><td><label for="scheduler_name">Scheduler:</label></td><td>
<select id="scheduler_name" name="scheduler_name">
<option value="automatic">Automatic</option>
<option value="uniform">Uniform</option>
<option value="karras">Karras</option>
<option value="exponential">Exponential</option>
<option value="polyexponential">Polyexponential</option>
<option value="sgm_uniform">SGM Uniform</option>
<option value="kl_optimal">KL Optimal</option>
<option value="align_your_steps">Align Your Steps</option>
<option value="simple" selected>Simple</option>
<option value="normal">Normal</option>
<option value="ddim">DDIM</option>
<option value="beta">Beta</option>
<option value="turbo">Turbo</option>
<option value="align_your_steps_GITS">Align Your Steps GITS</option>
<option value="align_your_steps_11">Align Your Steps 11</option>
<option value="align_your_steps_32">Align Your Steps 32</option>
</select> </select>
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
</td></tr> </td></tr>
<tr class="pl-5"><td><label>Image Size: </label></td><td id="image-size-options"> <tr class="pl-5"><td><label>Image Size: </label></td><td id="image-size-options">
<select id="width" name="width" value="512"> <select id="width" name="width" value="512">
@ -344,12 +434,14 @@
</div> </div>
</div> </div>
</div> </div>
<div id="small_image_warning" class="displayNone">Small image sizes can cause bad image quality</div> <div id="small_image_warning" class="displayNone warning-label">Small image sizes can cause bad image quality</div>
</td></tr> </td></tr>
<tr class="pl-5"><td><label for="num_inference_steps">Inference Steps:</label></td><td> <input id="num_inference_steps" name="num_inference_steps" type="number" min="1" step="1" style="width: 42pt" value="25" onkeypress="preventNonNumericalInput(event)" inputmode="numeric"></td></tr> <tr class="pl-5"><td><label for="num_inference_steps">Inference Steps:</label></td><td> <input id="num_inference_steps" name="num_inference_steps" type="number" min="1" step="1" style="width: 42pt" value="25" onkeypress="preventNonNumericalInput(event)" inputmode="numeric"></td></tr>
<tr class="pl-5"><td><label for="guidance_scale_slider">Guidance Scale:</label></td><td> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="11" max="500"> <input id="guidance_scale" name="guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr> <tr class="pl-5"><td><label for="guidance_scale_slider">Guidance Scale:</label></td><td> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="11" max="500"> <input id="guidance_scale" name="guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr class="pl-5 displayNone warning-label" id="guidanceWarning"><td></td><td id="guidanceWarningText"></td></tr>
<tr id="prompt_strength_container" class="pl-5"><td><label for="prompt_strength_slider">Prompt Strength:</label></td><td> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td></tr> <tr id="prompt_strength_container" class="pl-5"><td><label for="prompt_strength_slider">Prompt Strength:</label></td><td> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td></tr>
<tr id="lora_model_container" class="pl-5"> <tr id="distilled_guidance_scale_container" class="pl-5 displayNone"><td><label for="distilled_guidance_scale_slider">Distilled Guidance:</label></td><td> <input id="distilled_guidance_scale_slider" name="distilled_guidance_scale_slider" class="editor-slider" value="35" type="range" min="11" max="500"> <input id="distilled_guidance_scale" name="distilled_guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr id="lora_model_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">
<td> <td>
<label for="lora_model">LoRA:</label> <label for="lora_model">LoRA:</label>
</td> </td>
@ -357,14 +449,14 @@
<div id="lora_model" data-path=""></div> <div id="lora_model" data-path=""></div>
</td> </td>
</tr> </tr>
<tr id="hypernetwork_model_container" class="pl-5"><td><label for="hypernetwork_model">Hypernetwork:</label></td><td> <tr id="hypernetwork_model_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_classic"><td><label for="hypernetwork_model">Hypernetwork:</label></td><td>
<input id="hypernetwork_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" /> <input id="hypernetwork_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
</td></tr> </td></tr>
<tr id="hypernetwork_strength_container" class="pl-5"> <tr id="hypernetwork_strength_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_classic">
<td><label for="hypernetwork_strength_slider">Hypernetwork Strength:</label></td> <td><label for="hypernetwork_strength_slider">Hypernetwork Strength:</label></td>
<td> <input id="hypernetwork_strength_slider" name="hypernetwork_strength_slider" class="editor-slider" value="100" type="range" min="0" max="100"> <input id="hypernetwork_strength" name="hypernetwork_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td> <td> <input id="hypernetwork_strength_slider" name="hypernetwork_strength_slider" class="editor-slider" value="100" type="range" min="0" max="100"> <input id="hypernetwork_strength" name="hypernetwork_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td>
</tr> </tr>
<tr id="tiling_container" class="pl-5"> <tr id="tiling_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers">
<td><label for="tiling">Seamless Tiling:</label></td> <td><label for="tiling">Seamless Tiling:</label></td>
<td class="diffusers-restart-needed"> <td class="diffusers-restart-needed">
<select id="tiling" name="tiling"> <select id="tiling" name="tiling">
@ -389,7 +481,7 @@
<tr class="pl-5" id="output_quality_row"><td><label for="output_quality">Image Quality:</label></td><td> <tr class="pl-5" id="output_quality_row"><td><label for="output_quality">Image Quality:</label></td><td>
<input id="output_quality_slider" name="output_quality" class="editor-slider" value="75" type="range" min="10" max="95"> <input id="output_quality" name="output_quality" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="numeric"> <input id="output_quality_slider" name="output_quality" class="editor-slider" value="75" type="range" min="10" max="95"> <input id="output_quality" name="output_quality" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="numeric">
</td></tr> </td></tr>
<tr class="pl-5"> <tr class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers">
<td><label for="tiling">Enable VAE Tiling:</label></td> <td><label for="tiling">Enable VAE Tiling:</label></td>
<td class="diffusers-restart-needed"> <td class="diffusers-restart-needed">
<input id="enable_vae_tiling" name="enable_vae_tiling" type="checkbox" checked> <input id="enable_vae_tiling" name="enable_vae_tiling" type="checkbox" checked>
@ -405,7 +497,7 @@
<input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes</label> <div style="display:inline-block;"><input id="gfpgan_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" /></div> <input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes</label> <div style="display:inline-block;"><input id="gfpgan_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" /></div>
<table id="codeformer_settings" class="displayNone sub-settings"> <table id="codeformer_settings" class="displayNone sub-settings">
<tr class="pl-5"><td><label for="codeformer_fidelity_slider">Strength:</label></td><td><input id="codeformer_fidelity_slider" name="codeformer_fidelity_slider" class="editor-slider" value="5" type="range" min="0" max="10"> <input id="codeformer_fidelity" name="codeformer_fidelity" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr> <tr class="pl-5"><td><label for="codeformer_fidelity_slider">Strength:</label></td><td><input id="codeformer_fidelity_slider" name="codeformer_fidelity_slider" class="editor-slider" value="5" type="range" min="0" max="10"> <input id="codeformer_fidelity" name="codeformer_fidelity" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr class="pl-5"><td><label for="codeformer_upscale_faces">Upscale Faces:</label></td><td><input id="codeformer_upscale_faces" name="codeformer_upscale_faces" type="checkbox" checked> <label><small>(improves the resolution of faces)</small></label></td></tr> <tr class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers"><td><label for="codeformer_upscale_faces">Upscale Faces:</label></td><td><input id="codeformer_upscale_faces" name="codeformer_upscale_faces" type="checkbox" checked> <label><small>(improves the resolution of faces)</small></label></td></tr>
</table> </table>
</li> </li>
<li class="pl-5"> <li class="pl-5">
@ -418,7 +510,13 @@
<select id="upscale_model" name="upscale_model"> <select id="upscale_model" name="upscale_model">
<option value="RealESRGAN_x4plus" selected>RealESRGAN_x4plus</option> <option value="RealESRGAN_x4plus" selected>RealESRGAN_x4plus</option>
<option value="RealESRGAN_x4plus_anime_6B">RealESRGAN_x4plus_anime_6B</option> <option value="RealESRGAN_x4plus_anime_6B">RealESRGAN_x4plus_anime_6B</option>
<option value="latent_upscaler">Latent Upscaler 2x</option> <option value="ESRGAN_4x" class="pl-5 gated-feature" data-feature-keys="backend_webui">ESRGAN_4x</option>
<option value="Lanczos" class="pl-5 gated-feature" data-feature-keys="backend_webui">Lanczos</option>
<option value="Nearest" class="pl-5 gated-feature" data-feature-keys="backend_webui">Nearest</option>
<option value="ScuNET" class="pl-5 gated-feature" data-feature-keys="backend_webui">ScuNET GAN</option>
<option value="ScuNET PSNR" class="pl-5 gated-feature" data-feature-keys="backend_webui">ScuNET PSNR</option>
<option value="SwinIR_4x" class="pl-5 gated-feature" data-feature-keys="backend_webui">SwinIR 4x</option>
<option value="latent_upscaler" class="pl-5 gated-feature" data-feature-keys="backend_ed_classic backend_ed_diffusers">Latent Upscaler 2x</option>
</select> </select>
<table id="latent_upscaler_settings" class="displayNone sub-settings"> <table id="latent_upscaler_settings" class="displayNone sub-settings">
<tr class="pl-5"><td><label for="latent_upscaler_steps_slider">Upscaling Steps:</label></td><td><input id="latent_upscaler_steps_slider" name="latent_upscaler_steps_slider" class="editor-slider" value="10" type="range" min="1" max="50"> <input id="latent_upscaler_steps" name="latent_upscaler_steps" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="numeric"></td></tr> <tr class="pl-5"><td><label for="latent_upscaler_steps_slider">Upscaling Steps:</label></td><td><input id="latent_upscaler_steps_slider" name="latent_upscaler_steps_slider" class="editor-slider" value="10" type="range" min="1" max="50"> <input id="latent_upscaler_steps" name="latent_upscaler_steps" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="numeric"></td></tr>
@ -825,7 +923,8 @@
<p>The license of this software forbids you from sharing any content that violates any laws, causes harm to a person, discloses personal information intended for harm, <br/>spreads misinformation, or targets vulnerable groups. For the full list of restrictions, please read <a href="https://github.com/easydiffusion/easydiffusion/blob/main/LICENSE" target="_blank">the license</a>.</p> <p>The license of this software forbids you from sharing any content that violates any laws, causes harm to a person, discloses personal information intended for harm, <br/>spreads misinformation, or targets vulnerable groups. For the full list of restrictions, please read <a href="https://github.com/easydiffusion/easydiffusion/blob/main/LICENSE" target="_blank">the license</a>.</p>
<p>By using this software, you consent to the terms and conditions of the license.</p> <p>By using this software, you consent to the terms and conditions of the license.</p>
</div> </div>
<input id="test_diffusers" type="checkbox" style="display: none" checked /> <input id="test_diffusers" type="checkbox" style="display: none" checked /> <!-- for backwards compatibility -->
<input id="use_v3_engine" type="checkbox" style="display: none" checked /> <!-- for backwards compatibility -->
</div> </div>
</div> </div>
</body> </body>

View File

@ -9,6 +9,3 @@ server.init()
model_manager.init() model_manager.init()
app.init_render_threads() app.init_render_threads()
bucket_manager.init() bucket_manager.init()
# start the browser ui
app.open_browser()

View File

@ -79,6 +79,7 @@
} }
.parameters-table .fa-fire, .parameters-table .fa-fire,
.parameters-table .fa-bolt { .parameters-table .fa-bolt,
.parameters-table .fa-robot {
color: #F7630C; color: #F7630C;
} }

View File

@ -36,6 +36,15 @@ code {
transform: translateY(4px); transform: translateY(4px);
cursor: pointer; cursor: pointer;
} }
#engine-logo {
font-size: 8pt;
padding-left: 10pt;
color: var(--small-label-color);
}
#engine-logo a {
text-decoration: none;
/* color: var(--small-label-color); */
}
#prompt { #prompt {
width: 100%; width: 100%;
height: 65pt; height: 65pt;
@ -541,7 +550,7 @@ div.img-preview img {
position: relative; position: relative;
background: var(--background-color4); background: var(--background-color4);
display: flex; display: flex;
padding: 12px 0 0; padding: 6px 0 0;
} }
.tab .icon { .tab .icon {
padding-right: 4pt; padding-right: 4pt;
@ -657,6 +666,15 @@ div.img-preview img {
display: block; display: block;
} }
.gated-feature {
display: none;
}
.warning-label {
font-size: smaller;
color: var(--status-orange);
}
.display-settings { .display-settings {
float: right; float: right;
position: relative; position: relative;
@ -1459,11 +1477,6 @@ div.top-right {
margin-top: 6px; margin-top: 6px;
} }
#small_image_warning {
font-size: smaller;
color: var(--status-orange);
}
button#save-system-settings-btn { button#save-system-settings-btn {
padding: 4pt 8pt; padding: 4pt 8pt;
} }

View File

@ -16,10 +16,12 @@ const SETTINGS_IDS_LIST = [
"clip_skip", "clip_skip",
"vae_model", "vae_model",
"sampler_name", "sampler_name",
"scheduler_name",
"width", "width",
"height", "height",
"num_inference_steps", "num_inference_steps",
"guidance_scale", "guidance_scale",
"distilled_guidance_scale",
"prompt_strength", "prompt_strength",
"tiling", "tiling",
"output_format", "output_format",

View File

@ -131,6 +131,15 @@ const TASK_MAPPING = {
readUI: () => parseFloat(guidanceScaleField.value), readUI: () => parseFloat(guidanceScaleField.value),
parse: (val) => parseFloat(val), parse: (val) => parseFloat(val),
}, },
distilled_guidance_scale: {
name: "Distilled Guidance",
setUI: (distilled_guidance_scale) => {
distilledGuidanceScaleField.value = distilled_guidance_scale
updateDistilledGuidanceScaleSlider()
},
readUI: () => parseFloat(distilledGuidanceScaleField.value),
parse: (val) => parseFloat(val),
},
prompt_strength: { prompt_strength: {
name: "Prompt Strength", name: "Prompt Strength",
setUI: (prompt_strength) => { setUI: (prompt_strength) => {
@ -242,6 +251,14 @@ const TASK_MAPPING = {
readUI: () => samplerField.value, readUI: () => samplerField.value,
parse: (val) => val, parse: (val) => val,
}, },
scheduler_name: {
name: "Scheduler",
setUI: (scheduler_name) => {
schedulerField.value = scheduler_name
},
readUI: () => schedulerField.value,
parse: (val) => val,
},
use_stable_diffusion_model: { use_stable_diffusion_model: {
name: "Stable Diffusion model", name: "Stable Diffusion model",
setUI: (use_stable_diffusion_model) => { setUI: (use_stable_diffusion_model) => {
@ -590,11 +607,13 @@ const TASK_TEXT_MAPPING = {
seed: "Seed", seed: "Seed",
num_inference_steps: "Steps", num_inference_steps: "Steps",
guidance_scale: "Guidance Scale", guidance_scale: "Guidance Scale",
distilled_guidance_scale: "Distilled Guidance",
prompt_strength: "Prompt Strength", prompt_strength: "Prompt Strength",
use_face_correction: "Use Face Correction", use_face_correction: "Use Face Correction",
use_upscale: "Use Upscaling", use_upscale: "Use Upscaling",
upscale_amount: "Upscale By", upscale_amount: "Upscale By",
sampler_name: "Sampler", sampler_name: "Sampler",
scheduler_name: "Scheduler",
negative_prompt: "Negative Prompt", negative_prompt: "Negative Prompt",
use_stable_diffusion_model: "Stable Diffusion model", use_stable_diffusion_model: "Stable Diffusion model",
use_hypernetwork_model: "Hypernetwork model", use_hypernetwork_model: "Hypernetwork model",

View File

@ -12,8 +12,16 @@ const taskConfigSetup = {
seed: { value: ({ seed }) => seed, label: "Seed" }, seed: { value: ({ seed }) => seed, label: "Seed" },
dimensions: { value: ({ reqBody }) => `${reqBody?.width}x${reqBody?.height}`, label: "Dimensions" }, dimensions: { value: ({ reqBody }) => `${reqBody?.width}x${reqBody?.height}`, label: "Dimensions" },
sampler_name: "Sampler", sampler_name: "Sampler",
scheduler_name: {
label: "Scheduler",
visible: ({ reqBody }) => reqBody?.scheduler_name,
},
num_inference_steps: "Inference Steps", num_inference_steps: "Inference Steps",
guidance_scale: "Guidance Scale", guidance_scale: "Guidance Scale",
distilled_guidance_scale: {
label: "Distilled Guidance Scale",
visible: ({ reqBody }) => reqBody?.distilled_guidance_scale,
},
use_stable_diffusion_model: "Model", use_stable_diffusion_model: "Model",
clip_skip: { clip_skip: {
label: "Clip Skip", label: "Clip Skip",
@ -76,6 +84,8 @@ let numOutputsParallelField = document.querySelector("#num_outputs_parallel")
let numInferenceStepsField = document.querySelector("#num_inference_steps") let numInferenceStepsField = document.querySelector("#num_inference_steps")
let guidanceScaleSlider = document.querySelector("#guidance_scale_slider") let guidanceScaleSlider = document.querySelector("#guidance_scale_slider")
let guidanceScaleField = document.querySelector("#guidance_scale") let guidanceScaleField = document.querySelector("#guidance_scale")
let distilledGuidanceScaleSlider = document.querySelector("#distilled_guidance_scale_slider")
let distilledGuidanceScaleField = document.querySelector("#distilled_guidance_scale")
let outputQualitySlider = document.querySelector("#output_quality_slider") let outputQualitySlider = document.querySelector("#output_quality_slider")
let outputQualityField = document.querySelector("#output_quality") let outputQualityField = document.querySelector("#output_quality")
let outputQualityRow = document.querySelector("#output_quality_row") let outputQualityRow = document.querySelector("#output_quality_row")
@ -113,6 +123,8 @@ let promptStrengthSlider = document.querySelector("#prompt_strength_slider")
let promptStrengthField = document.querySelector("#prompt_strength") let promptStrengthField = document.querySelector("#prompt_strength")
let samplerField = document.querySelector("#sampler_name") let samplerField = document.querySelector("#sampler_name")
let samplerSelectionContainer = document.querySelector("#samplerSelection") let samplerSelectionContainer = document.querySelector("#samplerSelection")
let schedulerField = document.querySelector("#scheduler_name")
let schedulerSelectionContainer = document.querySelector("#schedulerSelection")
let useFaceCorrectionField = document.querySelector("#use_face_correction") let useFaceCorrectionField = document.querySelector("#use_face_correction")
let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model"), ["gfpgan", "codeformer"], "", false) let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model"), ["gfpgan", "codeformer"], "", false)
let useUpscalingField = document.querySelector("#use_upscale") let useUpscalingField = document.querySelector("#use_upscale")
@ -981,7 +993,20 @@ function onRedoFilter(req, img, e, tools) {
function onUpscaleClick(req, img, e, tools) { function onUpscaleClick(req, img, e, tools) {
let path = upscaleModelField.value let path = upscaleModelField.value
let scale = parseInt(upscaleAmountField.value) let scale = parseInt(upscaleAmountField.value)
let filterName = path.toLowerCase().includes("realesrgan") ? "realesrgan" : "latent_upscaler"
let filterName = null
const FILTERS = ["realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"]
for (const f of FILTERS) {
    if (path.toLowerCase().includes(f)) {
        filterName = f
        break
    }
}
if (!filterName) {
return
}
let statusText = "Upscaling by " + scale + "x using " + filterName let statusText = "Upscaling by " + scale + "x using " + filterName
applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools) applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools)
} }
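
The old binary choice (RealESRGAN or latent upscaler) becomes a substring lookup over the selected model's name, so each new WebUI upscaler resolves to its own filter, and unknown models now abort instead of silently falling back to the latent upscaler. The same lookup, condensed into a sketch:

```js
const FILTERS = ["realesrgan", "latent_upscaler", "esrgan_4x", "lanczos", "nearest", "scunet", "swinir"]

// Mirrors the loop in onUpscaleClick() above.
function resolveUpscaleFilter(path) {
    return FILTERS.find((f) => path.toLowerCase().includes(f)) ?? null
}

resolveUpscaleFilter("RealESRGAN_x4plus") // -> "realesrgan"
resolveUpscaleFilter("SwinIR_4x")         // -> "swinir"
resolveUpscaleFilter("ScuNET PSNR")       // -> "scunet"
resolveUpscaleFilter("UnknownModel")      // -> null, so the upscale is skipped
```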
@ -1038,6 +1063,9 @@ function makeImage() {
if (guidanceScaleField.value == "") { if (guidanceScaleField.value == "") {
guidanceScaleField.value = guidanceScaleSlider.value / 10 guidanceScaleField.value = guidanceScaleSlider.value / 10
} }
if (distilledGuidanceScaleField.value == "") {
distilledGuidanceScaleField.value = distilledGuidanceScaleSlider.value / 10
}
if (hypernetworkStrengthField.value == "") { if (hypernetworkStrengthField.value == "") {
hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100 hypernetworkStrengthField.value = hypernetworkStrengthSlider.value / 100
} }
@ -1406,6 +1434,12 @@ function getCurrentUserRequest() {
newTask.reqBody.control_filter_to_apply = controlImageFilterField.value newTask.reqBody.control_filter_to_apply = controlImageFilterField.value
} }
} }
if (stableDiffusionModelField.value.toLowerCase().includes("flux")) {
newTask.reqBody.distilled_guidance_scale = parseFloat(distilledGuidanceScaleField.value)
}
if (schedulerSelectionContainer.style.display !== "none") {
newTask.reqBody.scheduler_name = schedulerField.value
}
return newTask return newTask
} }
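
With these additions, a request for a Flux checkpoint carries the distilled guidance value, and `scheduler_name` is included whenever the scheduler row is visible (i.e. on the webui backend). A hypothetical `reqBody` fragment for illustration:

```js
// Hypothetical reqBody fragment; the field names are the ones set above.
const reqBody = {
    use_stable_diffusion_model: "flux1-dev", // any model name containing "flux"
    guidance_scale: 1.0,                     // Flux expects ~1, see checkGuidanceValue() below
    distilled_guidance_scale: 3.5,           // the slider default of 35, divided by 10
    sampler_name: "euler",
    scheduler_name: "simple",                // only sent while #schedulerSelection is visible
}
```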
@ -1845,36 +1879,93 @@ controlImagePreview.addEventListener("load", onControlnetModelChange)
controlImagePreview.addEventListener("unload", onControlnetModelChange) controlImagePreview.addEventListener("unload", onControlnetModelChange)
onControlnetModelChange() onControlnetModelChange()
// (the old onControlImageFilterChange() filter-to-model mapping was removed
//  here; it is preserved, commented out, further below)

// tip for Flux
let sdModelField = document.querySelector("#stable_diffusion_model")
function checkGuidanceValue() {
    let guidance = parseFloat(guidanceScaleField.value)
    let guidanceWarning = document.querySelector("#guidanceWarning")
    let guidanceWarningText = document.querySelector("#guidanceWarningText")
    if (sdModelField.value.toLowerCase().includes("flux")) {
        if (guidance > 1.5) {
            guidanceWarningText.innerText = "Flux recommends a 'Guidance Scale' of 1"
            guidanceWarning.classList.remove("displayNone")
        } else {
            guidanceWarning.classList.add("displayNone")
        }
    } else {
        if (guidance < 2) {
            guidanceWarningText.innerText = "A higher 'Guidance Scale' is recommended!"
            guidanceWarning.classList.remove("displayNone")
        } else {
            guidanceWarning.classList.add("displayNone")
        }
    }
}
sdModelField.addEventListener("change", checkGuidanceValue)
guidanceScaleField.addEventListener("change", checkGuidanceValue)
guidanceScaleSlider.addEventListener("change", checkGuidanceValue)
function checkGuidanceScaleVisibility() {
let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container")
if (sdModelField.value.toLowerCase().includes("flux")) {
guidanceScaleContainer.classList.remove("displayNone")
} else {
guidanceScaleContainer.classList.add("displayNone")
}
}
sdModelField.addEventListener("change", checkGuidanceScaleVisibility)
function checkFluxSampler() {
let samplerWarning = document.querySelector("#fluxSamplerWarning")
if (sdModelField.value.toLowerCase().includes("flux")) {
if (samplerField.value == "euler_a") {
samplerWarning.classList.remove("displayNone")
} else {
samplerWarning.classList.add("displayNone")
}
} else {
samplerWarning.classList.add("displayNone")
}
}
sdModelField.addEventListener("change", checkFluxSampler)
samplerField.addEventListener("change", checkFluxSampler)
document.addEventListener("refreshModels", function() {
checkGuidanceValue()
checkGuidanceScaleVisibility()
checkFluxSampler()
})
// function onControlImageFilterChange() {
// let filterId = controlImageFilterField.value
// if (filterId.includes("openpose")) {
// controlnetModelField.value = "control_v11p_sd15_openpose"
// } else if (filterId === "canny") {
// controlnetModelField.value = "control_v11p_sd15_canny"
// } else if (filterId === "mlsd") {
// controlnetModelField.value = "control_v11p_sd15_mlsd"
// } else if (filterId === "mlsd") {
// controlnetModelField.value = "control_v11p_sd15_mlsd"
// } else if (filterId.includes("scribble")) {
// controlnetModelField.value = "control_v11p_sd15_scribble"
// } else if (filterId.includes("softedge")) {
// controlnetModelField.value = "control_v11p_sd15_softedge"
// } else if (filterId === "normal_bae") {
// controlnetModelField.value = "control_v11p_sd15_normalbae"
// } else if (filterId.includes("depth")) {
// controlnetModelField.value = "control_v11f1p_sd15_depth"
// } else if (filterId === "lineart_anime") {
// controlnetModelField.value = "control_v11p_sd15s2_lineart_anime"
// } else if (filterId.includes("lineart")) {
// controlnetModelField.value = "control_v11p_sd15_lineart"
// } else if (filterId === "shuffle") {
// controlnetModelField.value = "control_v11e_sd15_shuffle"
// } else if (filterId === "segment") {
// controlnetModelField.value = "control_v11p_sd15_seg"
// }
// }
// controlImageFilterField.addEventListener("change", onControlImageFilterChange)
// onControlImageFilterChange()
upscaleModelField.disabled = !useUpscalingField.checked upscaleModelField.disabled = !useUpscalingField.checked
upscaleAmountField.disabled = !useUpscalingField.checked upscaleAmountField.disabled = !useUpscalingField.checked
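
Together, the three checks encode the model-specific advice: Flux wants a guidance scale of about 1 (the new Distilled Guidance field carries the real steering) and warns against Euler Ancestral, while other models still warn when guidance drops below 2. Condensed into a single sketch:

```js
// Sketch only - the real code above toggles the #guidanceWarning and
// #fluxSamplerWarning table rows instead of returning strings.
function warningsFor(modelName, guidance, samplerName) {
    const warnings = []
    const isFlux = modelName.toLowerCase().includes("flux")
    if (isFlux && guidance > 1.5) warnings.push("Flux recommends a 'Guidance Scale' of 1")
    if (!isFlux && guidance < 2) warnings.push("A higher 'Guidance Scale' is recommended!")
    if (isFlux && samplerName === "euler_a") warnings.push("Please avoid 'Euler Ancestral' with Flux!")
    return warnings
}

warningsFor("flux1-dev", 7.5, "euler_a") // -> both Flux warnings
warningsFor("sd-v1-5", 7.5, "euler_a")   // -> [] (no warnings)
```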
@ -1973,6 +2064,27 @@ guidanceScaleSlider.addEventListener("input", updateGuidanceScale)
guidanceScaleField.addEventListener("input", updateGuidanceScaleSlider) guidanceScaleField.addEventListener("input", updateGuidanceScaleSlider)
updateGuidanceScale() updateGuidanceScale()
/********************* Distilled Guidance **************************/
function updateDistilledGuidanceScale() {
distilledGuidanceScaleField.value = distilledGuidanceScaleSlider.value / 10
distilledGuidanceScaleField.dispatchEvent(new Event("change"))
}
function updateDistilledGuidanceScaleSlider() {
if (distilledGuidanceScaleField.value < 0) {
distilledGuidanceScaleField.value = 0
} else if (distilledGuidanceScaleField.value > 50) {
distilledGuidanceScaleField.value = 50
}
distilledGuidanceScaleSlider.value = distilledGuidanceScaleField.value * 10
distilledGuidanceScaleSlider.dispatchEvent(new Event("change"))
}
distilledGuidanceScaleSlider.addEventListener("input", updateDistilledGuidanceScale)
distilledGuidanceScaleField.addEventListener("input", updateDistilledGuidanceScaleSlider)
updateDistilledGuidanceScale()
/********************* Prompt Strength *******************/ /********************* Prompt Strength *******************/
function updatePromptStrength() { function updatePromptStrength() {
promptStrengthField.value = promptStrengthSlider.value / 100 promptStrengthField.value = promptStrengthSlider.value / 100
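
The distilled-guidance slider and text field mirror each other at a 10:1 scale, with the field clamped to [0, 50]; the slider default of 35 therefore shows as 3.5. A sketch of the mapping the two handlers maintain:

```js
// Sketch of the two-way binding above (the slider holds field * 10).
function fieldFromSlider(sliderValue) {
    return sliderValue / 10 // e.g. the default 35 -> 3.5
}

function sliderFromField(fieldValue) {
    const clamped = Math.min(50, Math.max(0, fieldValue)) // the field is clamped to [0, 50]
    return clamped * 10 // e.g. 3.5 -> 35
}
```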

View File

@ -102,7 +102,7 @@ var PARAMETERS = [
type: ParameterType.custom, type: ParameterType.custom,
icon: "fa-folder-tree", icon: "fa-folder-tree",
label: "Models Folder", label: "Models Folder",
note: "Path to the 'models' folder. Please save and refresh the page after changing this.", note: "Path to the 'models' folder. Please save and restart Easy Diffusion after changing this.",
saveInAppConfig: true, saveInAppConfig: true,
render: (parameter) => { render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="30">` return `<input id="${parameter.id}" name="${parameter.id}" size="30">`
@ -161,6 +161,7 @@ var PARAMETERS = [
"<b>Low:</b> slowest, recommended for GPUs with 3 to 4 GB memory", "<b>Low:</b> slowest, recommended for GPUs with 3 to 4 GB memory",
icon: "fa-forward", icon: "fa-forward",
default: "balanced", default: "balanced",
saveInAppConfig: true,
options: [ options: [
{ value: "balanced", label: "Balanced" }, { value: "balanced", label: "Balanced" },
{ value: "high", label: "High" }, { value: "high", label: "High" },
@ -249,14 +250,19 @@ var PARAMETERS = [
default: false, default: false,
}, },
{ {
id: "use_v3_engine", id: "backend",
type: ParameterType.checkbox, type: ParameterType.select,
label: "Use the new v3 engine (diffusers)", label: "Engine to use",
note: note:
"Use our new v3 engine, with additional features like LoRA, ControlNet, SDXL, Embeddings, Tiling and lots more! Please press Save, then restart the program after changing this.", "Use our new v3.5 engine (Forge), with additional features like Flux, SD3, Lycoris and lots more! Please press Save, then restart the program after changing this.",
icon: "fa-bolt", icon: "fa-robot",
default: true,
saveInAppConfig: true, saveInAppConfig: true,
default: "ed_diffusers",
options: [
{ value: "webui", label: "v3.5 (latest)" },
{ value: "ed_diffusers", label: "v3.0" },
{ value: "ed_classic", label: "v2.0" },
],
}, },
{ {
id: "cloudflare", id: "cloudflare",
@ -432,6 +438,7 @@ let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions")
let testDiffusers = document.querySelector("#use_v3_engine") let testDiffusers = document.querySelector("#use_v3_engine")
let backendEngine = document.querySelector("#backend")
let profileNameField = document.querySelector("#profileName") let profileNameField = document.querySelector("#profileName")
let modelsDirField = document.querySelector("#models_dir") let modelsDirField = document.querySelector("#models_dir")
@ -454,6 +461,23 @@ async function changeAppConfig(configDelta) {
} }
} }
function getDefaultDisplay(element) {
const tag = element.tagName.toLowerCase();
const defaultDisplays = {
div: 'block',
span: 'inline',
p: 'block',
tr: 'table-row',
table: 'table',
li: 'list-item',
ul: 'block',
ol: 'block',
button: 'inline',
// Add more if needed
};
return defaultDisplays[tag] || 'block'; // Default to 'block' if not listed
}
async function getAppConfig() { async function getAppConfig() {
try { try {
let res = await fetch("/get/app_config") let res = await fetch("/get/app_config")
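
`getDefaultDisplay()` exists because blindly re-showing a gated element with `display: block` would break table rows and inline elements; each tag gets its natural display value back. For instance, with the elements defined in index.html above:

```js
getDefaultDisplay(document.querySelector("#schedulerSelection")) // <tr>     -> "table-row"
getDefaultDisplay(document.querySelector("#engine-logo"))        // <div>    -> "block"
getDefaultDisplay(document.querySelector("#embeddings-button"))  // <button> -> "inline"
```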
@ -478,14 +502,16 @@ async function getAppConfig() {
modelsDirField.value = config.models_dir modelsDirField.value = config.models_dir
let testDiffusersEnabled = true let testDiffusersEnabled = true
if (config.use_v3_engine === false) { if (config.backend === "ed_classic") {
testDiffusersEnabled = false testDiffusersEnabled = false
} }
testDiffusers.checked = testDiffusersEnabled testDiffusers.checked = testDiffusersEnabled
backendEngine.value = config.backend
document.querySelector("#test_diffusers").checked = testDiffusers.checked // don't break plugins document.querySelector("#test_diffusers").checked = testDiffusers.checked // don't break plugins
document.querySelector("#use_v3_engine").checked = testDiffusers.checked // don't break plugins
if (config.config_on_startup) { if (config.config_on_startup) {
if (config.config_on_startup?.use_v3_engine) { if (config.config_on_startup?.backend !== "ed_classic") {
document.body.classList.add("diffusers-enabled-on-startup") document.body.classList.add("diffusers-enabled-on-startup")
document.body.classList.remove("diffusers-disabled-on-startup") document.body.classList.remove("diffusers-disabled-on-startup")
} else { } else {
@ -494,36 +520,26 @@ async function getAppConfig() {
} }
} }
if (!testDiffusersEnabled) { if (config.backend === "ed_classic") {
document.querySelector("#lora_model_container").style.display = "none"
document.querySelector("#tiling_container").style.display = "none"
document.querySelector("#controlnet_model_container").style.display = "none"
document.querySelector("#hypernetwork_model_container").style.display = ""
document.querySelector("#hypernetwork_strength_container").style.display = ""
document.querySelector("#negative-embeddings-button").style.display = "none"
document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => {
option.style.display = "none"
})
IMAGE_STEP_SIZE = 64 IMAGE_STEP_SIZE = 64
customWidthField.step = IMAGE_STEP_SIZE
customHeightField.step = IMAGE_STEP_SIZE
} else { } else {
document.querySelector("#lora_model_container").style.display = ""
document.querySelector("#tiling_container").style.display = ""
document.querySelector("#controlnet_model_container").style.display = ""
document.querySelector("#hypernetwork_model_container").style.display = "none"
document.querySelector("#hypernetwork_strength_container").style.display = "none"
document.querySelectorAll("#sampler_name option.k_diffusion-only").forEach((option) => {
option.style.display = "none"
})
document.querySelector("#clip_skip_config").classList.remove("displayNone")
document.querySelector("#embeddings-button").classList.remove("displayNone")
IMAGE_STEP_SIZE = 8 IMAGE_STEP_SIZE = 8
}
customWidthField.step = IMAGE_STEP_SIZE customWidthField.step = IMAGE_STEP_SIZE
customHeightField.step = IMAGE_STEP_SIZE customHeightField.step = IMAGE_STEP_SIZE
const currentBackendKey = "backend_" + config.backend
document.querySelectorAll('.gated-feature').forEach((element) => {
const featureKeys = element.getAttribute('data-feature-keys').split(' ')
if (featureKeys.includes(currentBackendKey)) {
element.style.display = getDefaultDisplay(element)
} else {
element.style.display = 'none'
} }
});
if (config.force_save_metadata) { if (config.force_save_metadata) {
metadataOutputFormatField.value = config.force_save_metadata metadataOutputFormatField.value = config.force_save_metadata
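
This loop replaces the long per-element show/hide blocks deleted above: any element tagged `gated-feature` starts hidden (the stylesheet sets `display: none`) and is re-enabled only if its space-separated `data-feature-keys` list contains the active backend's key. A sketch of the contract for a single element:

```js
// Assuming config.backend === "webui", so currentBackendKey === "backend_webui".
const currentBackendKey = "backend_webui"
const row = document.querySelector("#schedulerSelection") // data-feature-keys="backend_webui"

const featureKeys = row.getAttribute("data-feature-keys").split(" ")
row.style.display = featureKeys.includes(currentBackendKey)
    ? getDefaultDisplay(row) // a <tr>, so "table-row"
    : "none"                 // e.g. #tiling_container (backend_ed_diffusers) stays hidden
```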
@ -749,6 +765,11 @@ async function getSystemInfo() {
metadataOutputFormatField.disabled = !saveToDiskField.checked metadataOutputFormatField.disabled = !saveToDiskField.checked
} }
setDiskPath(res["default_output_dir"], force) setDiskPath(res["default_output_dir"], force)
// backend info
if (res["backend_url"]) {
document.querySelector("#backend-url").setAttribute("href", res["backend_url"])
}
} catch (e) { } catch (e) {
console.log("error fetching devices", e) console.log("error fetching devices", e)
} }