Compare commits

...

56 Commits

SHA1 Message Date
404329f9b5 Fix for image modifier improvements plugin 2023-08-03 16:33:20 +05:30
3929e88d87 Include the lora parser plugin as a core feature 2023-08-03 16:20:27 +05:30
83a5b5b46f Clamp controlnet images to multiples of 8 2023-08-03 15:51:39 +05:30
b97c906128 Fix a bug where setting an initial image would mess up the width and height field 2023-08-03 15:49:01 +05:30
b8328b6071 sdkit 1.0.165 - warn users about incompatible loras 2023-08-03 15:14:00 +05:30
9a528496a3 Reload the model if the path exists in the request but the model has been unloaded 2023-08-03 15:13:41 +05:30
6a95c602b1 sdkit 1.0.164 - Warn the user if the controlnet isn't compatible with the SD model version 2023-08-03 12:43:30 +05:30
f0f6578b9c Round image sizes to a multiple of 8 2023-08-03 10:22:24 +05:30
83c93eb9ef sdkit 1.0.163 - trt multi-gpu fix 2023-08-02 21:53:11 +05:30
befe8ad24e TRT logging 2023-08-02 18:55:09 +05:30
c5249e6144 TRT styling 2023-08-02 16:45:12 +05:30
9be3297c27 sdkit 1.0.162 - bug fixes for TRT 2023-08-02 16:42:24 +05:30
b6344ef6f9 sdkit 1.0.161 - bug fixes for TRT 2023-08-02 16:37:19 +05:30
76b7e32125 Bug fixes for TRT 2023-08-02 16:37:05 +05:30
801a3dd598 sdkit 1.0.160 - Dynamic load/unload of TensorRT engines 2023-08-02 15:34:55 +05:30
d1fdf1766a Allow batch size ranges again in TRT 2023-08-02 14:03:59 +05:30
35073adc1f Merge branch 'beta' of github.com:cmdr2/stable-diffusion-ui into beta 2023-08-02 12:36:15 +05:30
d76930c7f4 sdkit 1.0.159 - typo in TensorRT forward 2023-08-02 12:35:44 +05:30
7d496f4ad0 Add ControlNet model and filter to metadata (#1454) 2023-08-02 10:13:21 +05:30
53b5ce6e2c typo 2023-08-02 00:10:19 +05:30
38ab5b090f TRT ui changes 2023-08-02 00:08:43 +05:30
fa58996f37 sdkit 1.0.157 - tensorRT build configuration from the UI; clamp images to 8 instead of 64 pixels 2023-08-01 23:53:01 +05:30
56f92ccab0 Don't restrict TRT to batch size 1 2023-08-01 21:24:22 +05:30
4e444b418e sdkit 1.0.156 - missing jsons for controlnet 2023-08-01 18:23:34 +05:30
3d9a9299dc changelog 2023-08-01 17:42:15 +05:30
ae34c9e84b Download known controlnet models if selected; Auto-pick the recommended controlnet model when a filter is selected 2023-08-01 17:39:04 +05:30
eba7bab15e Allow named models in the dropdown 2023-08-01 16:16:38 +05:30
ee6db85768 Initial support for Controlnet 2023-08-01 15:39:15 +05:30
05ed110519 Don't show parallel field for tensorrt demo 2023-08-01 13:02:53 +05:30
9690fd1fa8 sdkit 1.0.154 - restrict tensorrt from 768x768 to 1024x1024, and Unet-only, to avoid going out of memory 2023-08-01 12:45:56 +05:30
4cee1be99c Default settings for TensorRT demo; Don't show splash screen for diffusers 2023-08-01 12:43:23 +05:30
d39e1da183 Fixes for TensorRT 2023-08-01 11:49:30 +05:30
8538a684e7 sdkit 1.0.153 - use TensorRT only if enabled in the UI 2023-07-31 13:19:56 +05:30
47d7513dd8 sdkit 1.0.152 - fix for black images with TensorRT, and enable a timing cache 2023-07-31 12:54:05 +05:30
432fd57581 Use the desired output format and quality while applying the quick filter 2023-07-30 14:06:31 +05:30
9c06e2612a changelog 2023-07-30 13:53:27 +05:30
1d6742f463 sdkit 1.0.151 - An option to use strict mask borders 2023-07-30 13:51:19 +05:30
2e849827d1 Restore width/height dropdown (#1445) 2023-07-30 10:16:04 +05:30
1e2c9ecb41 Use nvidia pypi index url for linux 2023-07-29 22:24:34 +05:30
14679586a8 changelog 2023-07-29 22:04:03 +05:30
11fb83a2a7 sdkit 1.0.148 - fix watermarking which is causing image artifacts in SDXL; fix SDXL long prompts with compel 2.0.1 2023-07-29 22:03:39 +05:30
4d3f55622a Support more image sizes (#1441)
* Support more image sizes
With diffusers, width and height must be a multiple of 8 (instead of 64), allowing more resolution values.

* Add swap button

* Change popup button icon
2023-07-29 21:42:48 +05:30
eedf6f0aad changelog 2023-07-29 21:30:43 +05:30
13592fae1a sdkit 1.0.147 - diffusers 0.19.2 - fix red specs in SDXL images 2023-07-29 21:29:24 +05:30
4dd05d3efe Merge branch 'trt' into beta 2023-07-29 21:10:00 +05:30
2e3059a7c8 UI for TensorRT installation and conversion 2023-07-29 21:09:27 +05:30
3b53b5ebaf sdkit 1.0.144 - use prompts for SDXL refiner; use VAE slicing and VAE tiling for larger images 2023-07-29 12:42:34 +05:30
a9f1000af8 Install button for TensorRT - displayed only if an NVIDIA gpu is active 2023-07-29 11:41:44 +05:30
a9960ded01 Styling 2023-07-29 10:14:52 +05:30
ed84a23f36 Redo button for image filters, limit undo buffer size to 5 2023-07-29 10:07:41 +05:30
8301cafb37 changelog 2023-07-29 09:23:15 +05:30
c906c5d14a Don't rely on old keys to exist in the request 2023-07-29 09:14:00 +05:30
6e52680fa8 Fast in-place upscale and face fix buttons, with an option to undo the operations 2023-07-28 22:48:41 +05:30
7f32c531d7 sdkit 1.0.143 - Fixes for the new beta 2023-07-28 19:14:40 +05:30
17a11b94b2 changelog 2023-07-28 18:59:10 +05:30
e61549e0cd Mega refactor of the task processing and rendering logic; Split filter into a separate task, and add support for running filter tasks individually; Change the format for sending model and filter data from the API, but maintain backwards compatibility for now with the old API 2023-07-28 18:57:28 +05:30
25 changed files with 2112 additions and 588 deletions

View File

@@ -22,6 +22,12 @@
Our focus continues to be an easy installation experience and an easy user interface, while still remaining quite powerful in terms of features and speed.
### Detailed changelog
* 2.5.48 - 1 Aug 2023 - (beta-only) Full support for ControlNets. You can select a control image to guide the AI. You can pick a filter to pre-process the image, and one of the known (or custom) controlnet models. Supports `OpenPose`, `Canny`, `Straight Lines`, `Depth`, `Line Art`, `Scribble`, `Soft Edge`, `Shuffle` and `Segment`.
* 2.5.47 - 30 Jul 2023 - An option to use `Strict Mask Border` while inpainting, to avoid touching areas outside the mask. But this might show a slight outline of the mask, which you will have to touch up separately.
* 2.5.47 - 29 Jul 2023 - (beta-only) Fix long prompts with SDXL.
* 2.5.47 - 29 Jul 2023 - (beta-only) Fix red dots in some SDXL images.
* 2.5.47 - 29 Jul 2023 - Significantly faster `Fix Faces` and `Upscale` buttons (on the image). They no longer need to generate the image from scratch; instead, they upscale/fix the generated image in-place.
* 2.5.47 - 28 Jul 2023 - Lots of internal code reorganization, in preparation for supporting Controlnets. No user-facing changes.
* 2.5.46 - 27 Jul 2023 - (beta-only) Full support for SD-XL models (base and refiner)!
* 2.5.45 - 24 Jul 2023 - (beta-only) Hide the samplers that won't be supported in the new diffusers version.
* 2.5.45 - 22 Jul 2023 - (beta-only) Fix the recently-broken inpainting models.
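
The ControlNet entry above maps to new request fields visible later in this diff (`control_image` on `GenerateImageRequest`, `control_filter_to_apply` on `TaskData`, and a `controlnet` model path). Below is a minimal sketch of what such a request might look like over the HTTP API; the port (9000) and the legacy-style `use_controlnet_model` field name are assumptions, not confirmed by this diff.

```python
import base64
import requests

# Hypothetical ControlNet render request; field names are taken from this diff
# where possible, everything else (port, legacy model field name) is assumed.
with open("pose.png", "rb") as f:
    control_image = "data:image/png;base64," + base64.b64encode(f.read()).decode()

payload = {
    "prompt": "a person dancing in the rain",
    "width": 512,
    "height": 512,
    "control_image": control_image,                        # GenerateImageRequest.control_image
    "control_filter_to_apply": "openpose",                 # pre-processing filter (TaskData)
    "use_controlnet_model": "control_v11p_sd15_openpose",  # assumed legacy-style field
}
res = requests.post("http://localhost:9000/render", json=payload).json()
print(res["stream"])  # poll this URL for progress and the generated images
```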

View File

@@ -18,7 +18,7 @@ os_name = platform.system()
modules_to_check = {
"torch": ("1.11.0", "1.13.1", "2.0.0"),
"torchvision": ("0.12.0", "0.14.1", "0.15.1"),
"sdkit": "1.0.142",
"sdkit": "1.0.165",
"stable-diffusion-sdkit": "2.1.4",
"rich": "12.6.0",
"uvicorn": "0.19.0",

View File

@@ -32,6 +32,8 @@ logging.basicConfig(
SD_DIR = os.getcwd()
ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, ".."))
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))

View File

@@ -5,10 +5,11 @@ import traceback
from typing import Union
from easydiffusion import app
from easydiffusion.types import TaskData
from easydiffusion.types import ModelsData
from easydiffusion.utils import log
from sdkit import Context
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
from sdkit.utils import hash_file_quick
KNOWN_MODEL_TYPES = [
@@ -19,6 +20,8 @@ KNOWN_MODEL_TYPES = [
"realesrgan",
"lora",
"codeformer",
"embeddings",
"controlnet",
]
MODEL_EXTENSIONS = {
"stable-diffusion": [".ckpt", ".safetensors"],
@@ -29,6 +32,7 @@ MODEL_EXTENSIONS = {
"lora": [".ckpt", ".safetensors"],
"codeformer": [".pth"],
"embeddings": [".pt", ".bin", ".safetensors"],
"controlnet": [".pth", ".safetensors"],
}
DEFAULT_MODELS = {
"stable-diffusion": [
@@ -57,7 +61,9 @@ def init():
def load_default_models(context: Context):
set_vram_optimizations(context)
from easydiffusion import runtime
runtime.set_vram_optimizations(context)
config = app.getConfig()
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
@@ -138,43 +144,32 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
def reload_models_if_necessary(context: Context, task_data: TaskData):
face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else ""
upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else ""
model_paths_in_req = {
"stable-diffusion": task_data.use_stable_diffusion_model,
"vae": task_data.use_vae_model,
"hypernetwork": task_data.use_hypernetwork_model,
"codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None,
"gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None,
"realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None,
"latent_upscaler": True if "latent_upscaler" in upscale_lower else None,
"nsfw_checker": True if task_data.block_nsfw else None,
"lora": task_data.use_lora_model,
}
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
models_to_reload = {
model_type: path
for model_type, path in model_paths_in_req.items()
if context.model_paths.get(model_type) != path
for model_type, path in models_data.model_paths.items()
if context.model_paths.get(model_type) != path or (path is not None and context.models.get(model_type) is None)
}
if task_data.codeformer_upscale_faces:
if models_data.model_paths.get("codeformer"):
if "realesrgan" not in models_to_reload and "realesrgan" not in context.models:
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None:
del models_to_reload["realesrgan"] # don't unload realesrgan
if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
for model_type in models_to_force_reload:
if model_type not in models_data.model_paths:
continue
models_to_reload[model_type] = models_data.model_paths[model_type]
for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req
action_fn = unload_model if context.model_paths[model_type] is None else load_model
extra_params = models_data.model_params.get(model_type, {})
try:
action_fn(context, model_type, scan_model=False) # we've scanned them already
action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already
if model_type in context.model_load_errors:
del context.model_load_errors[model_type]
except Exception as e:
@@ -183,24 +178,22 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks
def resolve_model_paths(task_data: TaskData):
task_data.use_stable_diffusion_model = resolve_model_to_use(
task_data.use_stable_diffusion_model, model_type="stable-diffusion"
)
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora")
if task_data.use_face_correction:
if "gfpgan" in task_data.use_face_correction.lower():
model_type = "gfpgan"
elif "codeformer" in task_data.use_face_correction.lower():
model_type = "codeformer"
def resolve_model_paths(models_data: ModelsData):
model_paths = models_data.model_paths
for model_type in model_paths:
skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"]
if model_type in skip_models: # doesn't use model paths
continue
if model_type == "codeformer":
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
elif model_type == "controlnet":
model_id = model_paths[model_type]
model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
if model_info:
filename = model_info.get("url", "").split("/")[-1]
download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False)
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type)
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
def fail_if_models_did_not_load(context: Context):
@@ -222,28 +215,17 @@ def download_default_models_if_necessary():
print(model_type, "model(s) found.")
def download_if_necessary(model_type: str, file_name: str, model_id: str):
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
model_path = os.path.join(app.MODELS_DIR, model_type, file_name)
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
other_models_exist = any_model_exists(model_type)
other_models_exist = any_model_exists(model_type) and skip_if_others_exist
known_model_exists = os.path.exists(model_path)
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
print("> download", model_type, model_id)
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR)
def set_vram_optimizations(context: Context):
config = app.getConfig()
vram_usage_level = config.get("vram_usage_level", "balanced")
if vram_usage_level != context.vram_usage_level:
context.vram_usage_level = vram_usage_level
return True
return False
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False)
def migrate_legacy_model_location():
@@ -266,16 +248,6 @@ def any_model_exists(model_type: str) -> bool:
return False
def set_clip_skip(context: Context, task_data: TaskData):
clip_skip = task_data.clip_skip
if clip_skip != context.clip_skip:
context.clip_skip = clip_skip
return True
return False
def make_model_folders():
for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
@@ -324,12 +296,26 @@ def is_malicious_model(file_path):
def getModels(scan_for_malicious: bool = True):
models = {
"options": {
"stable-diffusion": ["sd-v1-4"],
"stable-diffusion": [{"sd-v1-4": "SD 1.4"}],
"vae": [],
"hypernetwork": [],
"lora": [],
"codeformer": ["codeformer"],
"codeformer": [{"codeformer": "CodeFormer"}],
"embeddings": [],
"controlnet": [
{"control_v11p_sd15_canny": "Canny (*)"},
{"control_v11p_sd15_openpose": "OpenPose (*)"},
{"control_v11p_sd15_normalbae": "Normal BAE (*)"},
{"control_v11f1p_sd15_depth": "Depth (*)"},
{"control_v11p_sd15_scribble": "Scribble"},
{"control_v11p_sd15_softedge": "Soft Edge"},
{"control_v11p_sd15_inpaint": "Inpaint"},
{"control_v11p_sd15_lineart": "Line Art"},
{"control_v11p_sd15s2_lineart_anime": "Line Art Anime"},
{"control_v11p_sd15_mlsd": "Straight Lines"},
{"control_v11p_sd15_seg": "Segment"},
{"control_v11e_sd15_shuffle": "Shuffle"},
],
},
}
@@ -338,9 +324,9 @@ def getModels(scan_for_malicious: bool = True):
class MaliciousModelException(Exception):
"Raised when picklescan reports a problem with a model"
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]):
tree = list(default_entries)
nonlocal models_scanned
tree = []
for entry in sorted(
os.scandir(directory),
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
@@ -359,7 +345,14 @@ def getModels(scan_for_malicious: bool = True):
raise MaliciousModelException(entry.path)
if scan_for_malicious:
known_models[entry.path] = mtime
tree.append(entry.name[: -len(matching_suffix)])
model_id = entry.name[: -len(matching_suffix)]
model_exists = False
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m):
model_exists = True
break
if not model_exists:
tree.append(model_id)
elif entry.is_dir():
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
@@ -376,7 +369,8 @@ def getModels(scan_for_malicious: bool = True):
os.makedirs(models_dir)
try:
models["options"][model_type] = scan_directory(models_dir, model_extensions)
default_tree = models["options"].get(model_type, [])
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree)
except MaliciousModelException as e:
models["scan-error"] = str(e)
@@ -389,6 +383,7 @@ def getModels(scan_for_malicious: bool = True):
listModels(model_type="gfpgan")
listModels(model_type="lora")
listModels(model_type="embeddings")
listModels(model_type="controlnet")
if scan_for_malicious and models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
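
With this change, each entry in `models["options"][model_type]` is either a bare model id (from a directory scan) or a `{model_id: display_name}` dict (a default "named" model). A small client-side sketch for normalizing that mixed list, ignoring nested folders:

```python
def normalize_model_entries(entries):
    # Yields (model_id, display_name) pairs from the mixed str/dict list.
    for entry in entries:
        if isinstance(entry, dict):
            yield from entry.items()
        else:
            yield entry, entry  # plain string: the id doubles as the display name

options = [{"sd-v1-4": "SD 1.4"}, "my-custom-model"]
print(list(normalize_model_entries(options)))
# -> [('sd-v1-4', 'SD 1.4'), ('my-custom-model', 'my-custom-model')]
```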

View File

@@ -0,0 +1,98 @@
import sys
import os
import platform
from importlib.metadata import version as pkg_version
from sdkit.utils import log
from easydiffusion import app
# future home of scripts/check_modules.py
manifest = {
"tensorrt": {
"install": [
"nvidia-cudnn --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
"tensorrt-libs --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
"tensorrt --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
],
"uninstall": ["tensorrt"],
# TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error)
}
}
installing = []
# remove this once TRT releases on pypi
if platform.system() == "Windows":
trt_dir = os.path.join(app.ROOT_DIR, "tensorrt")
if os.path.exists(trt_dir):
files = os.listdir(trt_dir)
packages = manifest["tensorrt"]["install"]
packages = tuple(p.replace("-", "_") for p in packages)
wheels = []
for p in packages:
p = p.split(" ")[0]
f = next((f for f in files if f.startswith(p) and f.endswith((".whl", ".tar.gz"))), None)
if f:
wheels.append(os.path.join(trt_dir, f))
manifest["tensorrt"]["install"] = wheels
def get_installed_packages() -> dict:
return {module_name: version(module_name) for module_name in manifest if is_installed(module_name)}
def is_installed(module_name) -> bool:
return version(module_name) is not None
def install(module_name):
if is_installed(module_name):
log.info(f"{module_name} has already been installed!")
return
if module_name in installing:
log.info(f"{module_name} is already installing!")
return
if module_name not in manifest:
raise RuntimeError(f"Can't install unknown package: {module_name}!")
commands = manifest[module_name]["install"]
commands = [f"python -m pip install --upgrade {cmd}" for cmd in commands]
installing.append(module_name)
try:
for cmd in commands:
print(">", cmd)
if os.system(cmd) != 0:
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
finally:
installing.remove(module_name)
def uninstall(module_name):
if not is_installed(module_name):
log.info(f"{module_name} hasn't been installed!")
return
if module_name not in manifest:
raise RuntimeError(f"Can't uninstall unknown package: {module_name}!")
commands = manifest[module_name]["uninstall"]
commands = [f"python -m pip uninstall -y {cmd}" for cmd in commands]
for cmd in commands:
print(">", cmd)
if os.system(cmd) != 0:
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
def version(module_name: str) -> str:
try:
return pkg_version(module_name)
except Exception:
return None
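
A usage sketch for the new `package_manager` module, assuming it is called from within the easydiffusion process (where `app.ROOT_DIR` is initialized); the printed version is an example value:

```python
from easydiffusion import package_manager

if not package_manager.is_installed("tensorrt"):
    package_manager.install("tensorrt")  # runs the pip commands listed in `manifest`

# Maps installed manifest packages to their versions, e.g. {"tensorrt": "8.6.1"}
print(package_manager.get_installed_packages())
```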

View File

@@ -1,279 +0,0 @@
import json
import pprint
import queue
import time
from easydiffusion import device_manager
from easydiffusion.types import GenerateImageRequest
from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import Response, TaskData, UserInitiatedStop
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
from sdkit import Context
from sdkit.filter import apply_filters
from sdkit.generate import generate_images
from sdkit.models import load_model
from sdkit.utils import (
diffusers_latent_samples_to_images,
gc,
img_to_base64_str,
img_to_buffer,
latent_samples_to_images,
get_device_usage,
)
context = Context() # thread-local
"""
runtime data (bound locally to this thread), e.g. device, references to loaded models, optimization flags, etc.
"""
def init(device):
"""
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
"""
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
context.model_load_errors = {}
context.enable_codeformer = True
from easydiffusion import app
app_config = app.getConfig()
context.test_diffusers = (
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
)
log.info("Device usage during initialization:")
get_device_usage(device, log_info=True, process_usage_only=False)
device_manager.device_init(context, device)
def make_images(
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
):
context.stop_processing = False
print_task_info(req, task_data)
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
res = Response(
req,
task_data,
images=construct_response(images, seeds, task_data, base_seed=req.seed),
)
res = res.json()
data_queue.put(json.dumps(res))
log.info("Task completed")
return res
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
log.info(f"request: {req_str}")
log.info(f"task data: {task_str}")
def make_images_internal(
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
):
images, user_stopped = generate_images_internal(
req,
task_data,
data_queue,
task_temp_images,
step_callback,
task_data.stream_image_progress,
task_data.stream_image_progress_interval,
)
gc(context)
filtered_images = filter_images(req, task_data, images, user_stopped)
if task_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data)
seeds = [*range(req.seed, req.seed + len(images))]
if task_data.show_only_filtered_image or filtered_images is images:
return filtered_images, seeds
else:
return images + filtered_images, seeds + seeds
def generate_images_internal(
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
):
context.temp_images.clear()
callback = make_step_callback(
req,
task_data,
data_queue,
task_temp_images,
step_callback,
stream_image_progress,
stream_image_progress_interval,
)
try:
if req.init_image is not None and not context.test_diffusers:
req.sampler_name = "ddim"
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if context.partial_x_samples is not None:
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
else:
images = latent_samples_to_images(context, context.partial_x_samples)
finally:
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
if not context.test_diffusers:
del context.partial_x_samples
context.partial_x_samples = None
return images, user_stopped
def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped):
if user_stopped:
return images
if task_data.block_nsfw:
images = apply_filters(context, "nsfw_checker", images)
if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths["realesrgan"]
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
try:
images = apply_filters(
context,
"codeformer",
images,
upscale_faces=task_data.codeformer_upscale_faces,
codeformer_fidelity=task_data.codeformer_fidelity,
)
finally:
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
images = apply_filters(context, "gfpgan", images)
if task_data.use_upscale:
if "realesrgan" in task_data.use_upscale.lower():
images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount)
elif task_data.use_upscale == "latent_upscaler":
images = apply_filters(
context,
"latent_upscaler",
images,
scale=task_data.upscale_amount,
latent_upscaler_options={
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"seed": req.seed,
"num_inference_steps": task_data.latent_upscaler_steps,
"guidance_scale": 0,
},
)
return images
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
return [
ResponseImage(
data=img_to_base64_str(
img,
task_data.output_format,
task_data.output_quality,
task_data.output_lossless,
),
seed=seed,
)
for img, seed in zip(images, seeds)
]
def make_step_callback(
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, x_samples)
else:
images = latent_samples_to_images(context, x_samples)
if task_data.block_nsfw:
images = apply_filters(context, "nsfw_checker", images)
for i, img in enumerate(images):
buf = img_to_buffer(img, output_format="JPEG")
context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
del images
return partial_images
def on_image_step(x_samples, i, *args):
nonlocal last_callback_time
if context.test_diffusers:
context.partial_x_samples = (x_samples, args[0])
else:
context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images)
data_queue.put(json.dumps(progress))
step_callback()
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step

View File

@@ -0,0 +1,53 @@
"""
A runtime that runs on a specific device (in a thread).
It can run various tasks like image generation, image filtering, and model merging by using that thread-local context.
This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function.
"""
from easydiffusion import device_manager
from easydiffusion.utils import log
from sdkit import Context
from sdkit.utils import get_device_usage
context = Context() # thread-local
"""
runtime data (bound locally to this thread), e.g. device, references to loaded models, optimization flags, etc.
"""
def init(device):
"""
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
"""
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
context.model_load_errors = {}
context.enable_codeformer = True
from easydiffusion import app
app_config = app.getConfig()
context.test_diffusers = (
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
)
log.info("Device usage during initialization:")
get_device_usage(device, log_info=True, process_usage_only=False)
device_manager.device_init(context, device)
def set_vram_optimizations(context: Context):
from easydiffusion import app
config = app.getConfig()
vram_usage_level = config.get("vram_usage_level", "balanced")
if vram_usage_level != context.vram_usage_level:
context.vram_usage_level = vram_usage_level
return True
return False
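
A minimal usage sketch for the new `runtime` module, assuming one render thread per device (the pattern `task_manager.thread_render` uses below); the device string is illustrative:

```python
from easydiffusion import runtime

runtime.init("cuda:0")  # binds runtime.context to this thread and device
if runtime.set_vram_optimizations(runtime.context):
    # True means the configured vram_usage_level changed, so callers
    # (e.g. RenderTask below) force-reload the stable-diffusion model.
    pass
```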

View File

@@ -8,8 +8,17 @@ import os
import traceback
from typing import List, Union
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
from easydiffusion import app, model_manager, task_manager, package_manager
from easydiffusion.tasks import RenderTask, FilterTask
from easydiffusion.types import (
GenerateImageRequest,
FilterImageRequest,
MergeRequest,
TaskData,
ModelsData,
OutputFormatData,
convert_legacy_render_req_to_new,
)
from easydiffusion.utils import log
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
@@ -97,6 +106,10 @@ def init():
def render(req: dict):
return render_internal(req)
@server_api.post("/filter")
def filter_image(req: dict):
return filter_internal(req)
@server_api.post("/model/merge")
def model_merge(req: dict):
print(req)
@@ -122,6 +135,10 @@ def init():
def stop_cloudflare_tunnel(req: dict):
return stop_cloudflare_tunnel_internal(req)
@server_api.post("/package/{package_name:str}")
def modify_package(package_name: str, req: dict):
return modify_package_internal(package_name, req)
@server_api.get("/")
def read_root():
return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
@@ -213,24 +230,36 @@ def ping_internal(session_id: str = None):
if task_manager.current_state_error:
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail="Render thread is dead.")
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {"status": str(task_manager.current_state)}
if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True)
response["tasks"] = {id(t): t.status for t in session.tasks}
response["devices"] = task_manager.get_devices()
response["packages_installed"] = package_manager.get_installed_packages()
response["packages_installing"] = package_manager.installing
if cloudflare.address != None:
response["cloudflare"] = cloudflare.address
return JSONResponse(response, headers=NOCACHE_HEADERS)
def render_internal(req: dict):
try:
req = convert_legacy_render_req_to_new(req)
# separate out the request data into rendering and task-specific data
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
task_data: TaskData = TaskData.parse_obj(req)
models_data: ModelsData = ModelsData.parse_obj(req)
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
# Overwrite user specified save path
config = app.getConfig()
@@ -240,28 +269,53 @@ def render_internal(req: dict):
render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
app.save_to_config(
task_data.use_stable_diffusion_model,
task_data.use_vae_model,
task_data.use_hypernetwork_model,
models_data.model_paths.get("stable-diffusion"),
models_data.model_paths.get("vae"),
models_data.model_paths.get("hypernetwork"),
task_data.vram_usage_level,
)
# enqueue the task
new_task = task_manager.render(render_req, task_data)
task = RenderTask(render_req, task_data, models_data, output_format)
return enqueue_task(task)
except HTTPException as e:
raise e
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def filter_internal(req: dict):
try:
session_id = req.get("session_id", "session")
filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req)
models_data: ModelsData = ModelsData.parse_obj(req)
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
# enqueue the task
task = FilterTask(filter_req, session_id, models_data, output_format)
return enqueue_task(task)
except HTTPException as e:
raise e
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def enqueue_task(task):
try:
task_manager.enqueue_task(task)
response = {
"status": str(task_manager.current_state),
"queue": len(task_manager.tasks_queue),
"stream": f"/image/stream/{id(new_task)}",
"task": id(new_task),
"stream": f"/image/stream/{task.id}",
"task": task.id,
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def model_merge_internal(req: dict):
@@ -381,3 +435,19 @@ def stop_cloudflare_tunnel_internal(req: dict):
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def modify_package_internal(package_name: str, req: dict):
try:
cmd = req["command"]
if cmd not in ("install", "uninstall"):
raise RuntimeError(f"Unknown command: {cmd}")
cmd = getattr(package_manager, cmd)
cmd(package_name)
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
except Exception as e:
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
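
Client-side sketches for the two new endpoints, assuming the server's default port (an assumption); the request shapes follow `FilterImageRequest` and `modify_package_internal` in this diff:

```python
import requests

# 1) Run a standalone filter task (image is a base64 data URL, as /render uses):
res = requests.post("http://localhost:9000/filter", json={
    "session_id": "demo",
    "image": "data:image/png;base64,...",  # placeholder
    "filter": "gfpgan",
    "filter_params": {},
}).json()
print(res["stream"])  # e.g. "/image/stream/<task id>"

# 2) Install or uninstall a known package (command must be "install" or "uninstall"):
res = requests.post("http://localhost:9000/package/tensorrt", json={"command": "install"})
print(res.json())  # {"status": "OK"} on success
```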

View File

@@ -17,7 +17,7 @@ from typing import Any, Hashable
import torch
from easydiffusion import device_manager
from easydiffusion.types import GenerateImageRequest, TaskData
from easydiffusion.tasks import Task
from easydiffusion.utils import log
from sdkit.utils import gc
@@ -27,6 +27,7 @@ LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
MAX_OVERLOAD_ALLOWED_RATIO = 2 # i.e. 2x pending tasks compared to the number of render threads
class SymbolClass(type): # Print nicely formatted Symbol names.
@@ -58,46 +59,6 @@ class ServerStates:
pass
class RenderTask: # Task with output queue and completion lock.
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
task_data.request_id = id(self)
self.render_request: GenerateImageRequest = req # Initial Request
self.task_data: TaskData = task_data
self.response: Any = None # Copy of the last response
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
async def read_buffer_generator(self):
try:
while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done()
yield res
except queue.Empty as e:
yield
@property
def status(self):
if self.lock.locked():
return "running"
if isinstance(self.error, StopAsyncIteration):
return "stopped"
if self.error:
return "error"
if not self.buffer_queue.empty():
return "buffer"
if self.response:
return "completed"
return "pending"
@property
def is_pending(self):
return bool(not self.response and not self.error)
# Temporary cache that allows querying task results for a short time after they complete.
class DataCache:
def __init__(self):
@@ -123,8 +84,8 @@ class DataCache:
# Remove Items
for key in to_delete:
(_, val) = self._base[key]
if isinstance(val, RenderTask):
log.debug(f"RenderTask {key} expired. Data removed.")
if isinstance(val, Task):
log.debug(f"Task {key} expired. Data removed.")
elif isinstance(val, SessionState):
log.debug(f"Session {key} expired. Data removed.")
else:
@@ -220,8 +181,8 @@ class SessionState:
tasks.append(task)
return tasks
def put(self, task, ttl=TASK_TTL):
task_id = id(task)
def put(self, task: Task, ttl=TASK_TTL):
task_id = task.id
self._tasks_ids.append(task_id)
if not task_cache.put(task_id, task, ttl):
return False
@@ -230,11 +191,16 @@ class SessionState:
return True
def keep_task_alive(task: Task):
task_cache.keep(task.id, TASK_TTL)
session_cache.keep(task.session_id, TASK_TTL)
def thread_get_next_task():
from easydiffusion import renderer
from easydiffusion import runtime
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
log.warn(f"Render thread on device: {runtime.context.device} failed to acquire manager lock.")
return None
if len(tasks_queue) <= 0:
manager_lock.release()
@@ -242,7 +208,7 @@ def thread_get_next_task():
task = None
try: # Select a render task.
for queued_task in tasks_queue:
if queued_task.render_device and renderer.context.device != queued_task.render_device:
if queued_task.render_device and runtime.context.device != queued_task.render_device:
# Is asking for a specific render device.
if is_alive(queued_task.render_device) > 0:
continue # requested device alive, skip current one.
@@ -251,7 +217,7 @@ def thread_get_next_task():
queued_task.error = Exception(queued_task.render_device + " is not currently active.")
task = queued_task
break
if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
if not queued_task.render_device and runtime.context.device == "cpu" and is_alive() > 1:
# not asking for any specific device; the CPU wants to grab the task, but other render devices are alive.
continue # Skip this task; don't run on CPU unless there is nothing else, or the user asked for it.
task = queued_task
@@ -266,19 +232,19 @@ def thread_get_next_task():
def thread_render(device):
global current_state, current_state_error
from easydiffusion import model_manager, renderer
from easydiffusion import model_manager, runtime
try:
renderer.init(device)
runtime.init(device)
weak_thread_data[threading.current_thread()] = {
"device": renderer.context.device,
"device_name": renderer.context.device_name,
"device": runtime.context.device,
"device_name": runtime.context.device_name,
"alive": True,
}
current_state = ServerStates.LoadingModel
model_manager.load_default_models(renderer.context)
model_manager.load_default_models(runtime.context)
current_state = ServerStates.Online
except Exception as e:
@@ -290,8 +256,8 @@ def thread_render(device):
session_cache.clean()
task_cache.clean()
if not weak_thread_data[threading.current_thread()]["alive"]:
log.info(f"Shutting down thread for device {renderer.context.device}")
model_manager.unload_all(renderer.context)
log.info(f"Shutting down thread for device {runtime.context.device}")
model_manager.unload_all(runtime.context)
return
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
@@ -311,62 +277,31 @@ def thread_render(device):
task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
log.info(f"Session {task.session_id} starting task {task.id} on {runtime.context.device_name}")
if not task.lock.acquire(blocking=False):
raise Exception("Got locked task from queue.")
try:
task.run()
def step_callback():
global current_state_error
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
if (
isinstance(current_state_error, SystemExit)
or isinstance(current_state_error, StopAsyncIteration)
or isinstance(task.error, StopAsyncIteration)
):
renderer.context.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
current_state = ServerStates.LoadingModel
model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
model_manager.fail_if_models_did_not_load(renderer.context)
current_state = ServerStates.Rendering
task.response = renderer.make_images(
task.render_request,
task.task_data,
task.buffer_queue,
task.temp_images,
step_callback,
)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
keep_task_alive(task)
except Exception as e:
task.error = str(e)
task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
log.error(traceback.format_exc())
finally:
gc(renderer.context)
gc(runtime.context)
task.lock.release()
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
keep_task_alive(task)
if isinstance(task.error, StopAsyncIteration):
log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
log.info(f"Session {task.session_id} task {task.id} cancelled!")
elif task.error is not None:
log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
log.info(f"Session {task.session_id} task {task.id} failed!")
else:
log.info(
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
)
log.info(f"Session {task.session_id} task {task.id} completed by {runtime.context.device_name}.")
current_state = ServerStates.Online
@@ -438,6 +373,12 @@ def get_devices():
finally:
manager_lock.release()
# temp until TRT releases
import os
from easydiffusion import app
devices["enable_trt"] = os.path.exists(os.path.join(app.ROOT_DIR, "tensorrt"))
return devices
@@ -548,28 +489,27 @@ def shutdown_event(): # Signal render thread to close on shutdown
current_state_error = SystemExit("Application shutting down.")
def render(render_req: GenerateImageRequest, task_data: TaskData):
def enqueue_task(task: Task):
current_thread_count = is_alive()
if current_thread_count <= 0: # Render thread is dead
raise ChildProcessError("Rendering thread has died.")
# Alive, check if task in cache
session = get_cached_session(task_data.session_id, update_ttl=True)
session = get_cached_session(task.session_id, update_ttl=True)
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
if current_thread_count < len(pending_tasks):
if len(pending_tasks) > current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO:
raise ConnectionRefusedError(
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
f"Session {task.session_id} already has {len(pending_tasks)} pending tasks, with {current_thread_count} workers."
)
new_task = RenderTask(render_req, task_data)
if session.put(new_task, TASK_TTL):
if session.put(task, TASK_TTL):
# Use twice the normal timeout for adding user requests.
# Tries to force session.put to fail before tasks_queue.put would.
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
try:
tasks_queue.append(new_task)
tasks_queue.append(task)
idle_event.set()
return new_task
return task
finally:
manager_lock.release()
raise RuntimeError("Failed to add task to cache.")
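
To make the back-pressure rule in `enqueue_task` concrete: with `MAX_OVERLOAD_ALLOWED_RATIO = 2`, a session may hold at most twice as many pending tasks as there are live render threads. A worked sketch of just that check:

```python
MAX_OVERLOAD_ALLOWED_RATIO = 2

def accepts_new_task(pending_tasks: int, render_threads: int) -> bool:
    # Mirrors the ConnectionRefusedError condition in enqueue_task above.
    return pending_tasks <= render_threads * MAX_OVERLOAD_ALLOWED_RATIO

print(accepts_new_task(pending_tasks=4, render_threads=2))  # True  (4 <= 4)
print(accepts_new_task(pending_tasks=5, render_threads=2))  # False (rejected as HTTP 503 upstream)
```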

View File

@@ -0,0 +1,3 @@
from .task import Task
from .render_images import RenderTask
from .filter_images import FilterTask

View File

@@ -0,0 +1,110 @@
import json
import pprint
from sdkit.filter import apply_filters
from sdkit.models import load_model
from sdkit.utils import img_to_base64_str, log
from easydiffusion import model_manager, runtime
from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData
from .task import Task
class FilterTask(Task):
"For applying filters to input images"
def __init__(
self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData
):
super().__init__(session_id)
self.request = req
self.models_data = models_data
self.output_format = output_format
# convert to multi-filter format, if necessary
if isinstance(req.filter, str):
req.filter_params = {req.filter: req.filter_params}
req.filter = [req.filter]
if not isinstance(req.image, list):
req.image = [req.image]
def run(self):
"Runs the image filtering task on the assigned thread"
context = runtime.context
model_manager.resolve_model_paths(self.models_data)
model_manager.reload_models_if_necessary(context, self.models_data)
model_manager.fail_if_models_did_not_load(context)
print_task_info(self.request, self.models_data, self.output_format)
images = filter_images(context, self.request.image, self.request.filter, self.request.filter_params)
output_format = self.output_format
images = [
img_to_base64_str(
img, output_format.output_format, output_format.output_quality, output_format.output_lossless
)
for img in images
]
res = FilterImageResponse(self.request, self.models_data, images=images)
res = res.json()
self.buffer_queue.put(json.dumps(res))
log.info("Filter task completed")
self.response = res
def filter_images(context, images, filters, filter_params={}):
filters = filters if isinstance(filters, list) else [filters]
for filter_name in filters:
params = filter_params.get(filter_name, {})
previous_state = before_filter(context, filter_name, params)
try:
images = apply_filters(context, filter_name, images, **params)
finally:
after_filter(context, filter_name, params, previous_state)
return images
def before_filter(context, filter_name, filter_params):
if filter_name == "codeformer":
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
upscale_faces = filter_params.get("upscale_faces", False)
if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths.get("realesrgan")
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
return prev_realesrgan_path
def after_filter(context, filter_name, filter_params, previous_state):
if filter_name == "codeformer":
prev_realesrgan_path = previous_state
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData):
req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
log.info(f"request: {req_str}")
log.info(f"models data: {models_data}")
log.info(f"output format: {output_format}")
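
A usage sketch for `filter_images` above, assuming the relevant filter models are already resolved and loaded on this thread's context; `load_my_images` is a hypothetical helper returning PIL images:

```python
from easydiffusion import runtime
from easydiffusion.tasks.filter_images import filter_images

images = load_my_images()  # hypothetical: a list of PIL.Image objects
images = filter_images(
    runtime.context,
    images,
    ["gfpgan", "realesrgan"],      # filters are applied in order
    {"realesrgan": {"scale": 4}},  # per-filter keyword params
)
```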

View File

@@ -0,0 +1,340 @@
import json
import pprint
import queue
import time
from easydiffusion import model_manager, runtime
from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData
from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
from sdkit.generate import generate_images
from sdkit.utils import (
diffusers_latent_samples_to_images,
gc,
img_to_base64_str,
img_to_buffer,
latent_samples_to_images,
log,
)
from .task import Task
from .filter_images import filter_images
class RenderTask(Task):
"For image generation"
def __init__(
self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
):
super().__init__(task_data.session_id)
task_data.request_id = self.id
self.render_request: GenerateImageRequest = req # Initial Request
self.task_data: TaskData = task_data
self.models_data = models_data
self.output_format = output_format
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
def run(self):
"Runs the image generation task on the assigned thread"
from easydiffusion import task_manager
context = runtime.context
def step_callback():
task_manager.keep_task_alive(self)
task_manager.current_state = task_manager.ServerStates.Rendering
if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance(
self.error, StopAsyncIteration
):
context.stop_processing = True
if isinstance(task_manager.current_state_error, StopAsyncIteration):
self.error = task_manager.current_state_error
task_manager.current_state_error = None
log.info(f"Session {self.session_id} sent cancel signal for task {self.id}")
task_manager.current_state = task_manager.ServerStates.LoadingModel
model_manager.resolve_model_paths(self.models_data)
models_to_force_reload = []
if (
runtime.set_vram_optimizations(context)
or self.has_param_changed(context, "clip_skip")
or self.trt_needs_reload(context)
):
models_to_force_reload.append("stable-diffusion")
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
model_manager.fail_if_models_did_not_load(context)
task_manager.current_state = task_manager.ServerStates.Rendering
self.response = make_images(
context,
self.render_request,
self.task_data,
self.models_data,
self.output_format,
self.buffer_queue,
self.temp_images,
step_callback,
)
def has_param_changed(self, context, param_name):
if not context.test_diffusers:
return False
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
return True
model = context.models["stable-diffusion"]
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
return model["params"].get(param_name) != new_val
def trt_needs_reload(self, context):
if not context.test_diffusers:
return False
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
return True
model = context.models["stable-diffusion"]
# curr_convert_to_trt = model["params"].get("convert_to_tensorrt")
new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False)
pipe = model["default"]
is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr(
pipe.unet, "_allocate_trt_buffers_backup"
)
if new_convert_to_trt and not is_trt_loaded:
return True
curr_build_config = model["params"].get("trt_build_config")
new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {})
return new_convert_to_trt and curr_build_config != new_build_config
def make_images(
context,
req: GenerateImageRequest,
task_data: TaskData,
models_data: ModelsData,
output_format: OutputFormatData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
):
context.stop_processing = False
print_task_info(req, task_data, models_data, output_format)
images, seeds = make_images_internal(
context, req, task_data, models_data, output_format, data_queue, task_temp_images, step_callback
)
res = GenerateImageResponse(
req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format)
)
res = res.json()
data_queue.put(json.dumps(res))
log.info("Task completed")
return res
def print_task_info(
req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
):
req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
log.info(f"request: {req_str}")
log.info(f"task data: {task_str}")
# log.info(f"models data: {models_data}")
log.info(f"output format: {output_format}")
def make_images_internal(
context,
req: GenerateImageRequest,
task_data: TaskData,
models_data: ModelsData,
output_format: OutputFormatData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
):
images, user_stopped = generate_images_internal(
context,
req,
task_data,
models_data,
data_queue,
task_temp_images,
step_callback,
task_data.stream_image_progress,
task_data.stream_image_progress_interval,
)
gc(context)
filters, filter_params = task_data.filters, task_data.filter_params
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
if task_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data, output_format)
seeds = [*range(req.seed, req.seed + len(images))]
if task_data.show_only_filtered_image or filtered_images is images:
return filtered_images, seeds
else:
return images + filtered_images, seeds + seeds
def generate_images_internal(
context,
req: GenerateImageRequest,
task_data: TaskData,
models_data: ModelsData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
):
context.temp_images.clear()
callback = make_step_callback(
context,
req,
task_data,
data_queue,
task_temp_images,
step_callback,
stream_image_progress,
stream_image_progress_interval,
)
try:
if req.init_image is not None and not context.test_diffusers:
req.sampler_name = "ddim"
req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8
if req.control_image and task_data.control_filter_to_apply:
req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0]
if context.test_diffusers:
pipe = context.models["stable-diffusion"]["default"]
if hasattr(pipe.unet, "_allocate_trt_buffers_backup"):
setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup)
delattr(pipe.unet, "_allocate_trt_buffers_backup")
if hasattr(pipe.unet, "_allocate_trt_buffers"):
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
if convert_to_trt:
pipe.unet.forward = pipe.unet._trt_forward
# pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward
log.info(f"Setting unet.forward to TensorRT")
else:
log.info(f"Not using TensorRT for unet.forward")
pipe.unet.forward = pipe.unet._non_trt_forward
# pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward
setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers)
delattr(pipe.unet, "_allocate_trt_buffers")
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if context.partial_x_samples is not None:
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
else:
images = latent_samples_to_images(context, context.partial_x_samples)
finally:
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
if not context.test_diffusers:
del context.partial_x_samples
context.partial_x_samples = None
return images, user_stopped
def construct_response(images: list, seeds: list, output_format: OutputFormatData):
return [
ResponseImage(
data=img_to_base64_str(
img,
output_format.output_format,
output_format.output_quality,
output_format.output_lossless,
),
seed=seed,
)
for img, seed in zip(images, seeds)
]
def make_step_callback(
context,
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
stream_image_progress: bool,
stream_image_progress_interval: int,
):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
if context.test_diffusers:
images = diffusers_latent_samples_to_images(context, x_samples)
else:
images = latent_samples_to_images(context, x_samples)
if task_data.block_nsfw:
images = filter_images(context, images, "nsfw_checker")
for i, img in enumerate(images):
buf = img_to_buffer(img, output_format="JPEG")
context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
del images
return partial_images
def on_image_step(x_samples, i, *args):
nonlocal last_callback_time
if context.test_diffusers:
context.partial_x_samples = (x_samples, args[0])
else:
context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images)
data_queue.put(json.dumps(progress))
step_callback()
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step
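
A minimal sketch of how the renderer wires this callback into sdkit's generate_images (the objects are the same ones used above; the lambda and the interval value are placeholders):

from queue import Queue

data_queue = Queue()  # JSON progress segments, streamed to the HTTP client
task_temp_images = [None] * req.num_outputs  # slots for partial preview images
callback = make_step_callback(
    context, req, task_data, data_queue, task_temp_images,
    lambda: None,  # step_callback: per-step hook from the task scheduler (placeholder)
    True,  # stream_image_progress
    5,  # stream_image_progress_interval
)
images = generate_images(context, callback=callback, **req.dict())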

View File

@ -0,0 +1,47 @@
from threading import Lock
from queue import Queue, Empty as EmptyQueueException
from typing import Any
class Task:
"Task with output queue and completion lock"
def __init__(self, session_id):
self.id = id(self)
self.session_id = session_id
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.error: Exception = None
self.lock: Lock = Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: Queue = Queue() # Queue of JSON string segments
self.response: Any = None # Copy of the last response
async def read_buffer_generator(self):
try:
while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done()
yield res
except EmptyQueueException:
yield
@property
def status(self):
if self.lock.locked():
return "running"
if isinstance(self.error, StopAsyncIteration):
return "stopped"
if self.error:
return "error"
if not self.buffer_queue.empty():
return "buffer"
if self.response:
return "completed"
return "pending"
@property
def is_pending(self):
return bool(not self.response and not self.error)
def run(self):
"Override this to implement the task's behavior"
pass
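
For context, a hypothetical subclass sketch showing how the fields interact: run() holds the lock while working (status "running"), streams JSON segments through buffer_queue (status "buffer" until drained), and sets response on success (status "completed"):

class ExampleRenderTask(Task):
    def __init__(self, session_id, req):
        super().__init__(session_id)
        self.req = req

    def run(self):
        with self.lock:  # status reads "running" while the lock is held
            try:
                self.buffer_queue.put('{"step": 0}')  # status reads "buffer" until drained
                self.response = {"status": "succeeded"}  # then "completed"
            except Exception as e:
                self.error = e  # status reads "error" once the lock is released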

View File

@ -1,4 +1,4 @@
from typing import Any, List, Union
from typing import Any, List, Dict, Union
from pydantic import BaseModel
@ -17,8 +17,11 @@ class GenerateImageRequest(BaseModel):
init_image: Any = None
init_image_mask: Any = None
control_image: Any = None
control_alpha: Union[float, List[float]] = None
prompt_strength: float = 0.8
preserve_init_image_color_profile = False
strict_mask_border = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
hypernetwork_strength: float = 0
@ -26,6 +29,35 @@ class GenerateImageRequest(BaseModel):
tiling: str = "none" # "none", "x", "y", "xy"
class FilterImageRequest(BaseModel):
image: Any = None
filter: Union[str, List[str]] = None
filter_params: dict = {}
class ModelsData(BaseModel):
"""
Contains the information related to the models involved in a request.
- To load a model: set the relative path(s) to the model in `model_paths`. No effect if already loaded.
- To unload a model: set the model to `None` in `model_paths`. No effect if already unloaded.
Models that aren't present in `model_paths` will not be changed.
"""
model_paths: Dict[str, Union[str, None, List[str]]] = None
"model_type to string path, or list of string paths"
model_params: Dict[str, Dict[str, Any]] = {}
"model_type to dict of parameters"
class OutputFormatData(BaseModel):
output_format: str = "jpeg" # or "png" or "webp"
output_quality: int = 75
output_lossless: bool = False
class TaskData(BaseModel):
request_id: str = None
session_id: str = "session"
@ -40,12 +72,13 @@ class TaskData(BaseModel):
use_vae_model: Union[str, List[str]] = None
use_hypernetwork_model: Union[str, List[str]] = None
use_lora_model: Union[str, List[str]] = None
use_controlnet_model: Union[str, List[str]] = None
filters: List[str] = []
filter_params: Dict[str, Dict[str, Any]] = {}
control_filter_to_apply: Union[str, List[str]] = None
show_only_filtered_image: bool = False
block_nsfw: bool = False
output_format: str = "jpeg" # or "png" or "webp"
output_quality: int = 75
output_lossless: bool = False
metadata_output_format: str = "txt" # or "json"
stream_image_progress: bool = False
stream_image_progress_interval: int = 5
@ -80,24 +113,39 @@ class Image:
}
class Response:
class GenerateImageResponse:
render_request: GenerateImageRequest
task_data: TaskData
models_data: ModelsData
images: list
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
def __init__(
self,
render_request: GenerateImageRequest,
task_data: TaskData,
models_data: ModelsData,
output_format: OutputFormatData,
images: list,
):
self.render_request = render_request
self.task_data = task_data
self.models_data = models_data
self.output_format = output_format
self.images = images
def json(self):
del self.render_request.init_image
del self.render_request.init_image_mask
del self.render_request.control_image
task_data = self.task_data.dict()
task_data.update(self.output_format.dict())
res = {
"status": "succeeded",
"render_request": self.render_request.dict(),
"task_data": self.task_data.dict(),
"task_data": task_data,
# "models_data": self.models_data.dict(), # haven't migrated the UI to the new format (yet)
"output": [],
}
@ -107,5 +155,111 @@ class Response:
return res
class FilterImageResponse:
request: FilterImageRequest
models_data: ModelsData
images: list
def __init__(self, request: FilterImageRequest, models_data: ModelsData, images: list):
self.request = request
self.models_data = models_data
self.images = images
def json(self):
del self.request.image
res = {
"status": "succeeded",
"request": self.request.dict(),
"models_data": self.models_data.dict(),
"output": [],
}
for image in self.images:
res["output"].append(image)
return res
class UserInitiatedStop(Exception):
pass
def convert_legacy_render_req_to_new(old_req: dict):
new_req = dict(old_req)
# new keys
model_paths = new_req["model_paths"] = {}
model_params = new_req["model_params"] = {}
filters = new_req["filters"] = []
filter_params = new_req["filter_params"] = {}
# move the model info
model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
model_paths["vae"] = old_req.get("use_vae_model")
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
model_paths["lora"] = old_req.get("use_lora_model")
model_paths["controlnet"] = old_req.get("use_controlnet_model")
model_paths["gfpgan"] = old_req.get("use_face_correction", "")
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None
model_paths["codeformer"] = old_req.get("use_face_correction", "")
model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None
model_paths["realesrgan"] = old_req.get("use_upscale", "")
model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None
model_paths["latent_upscaler"] = old_req.get("use_upscale", "")
model_paths["latent_upscaler"] = (
model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
)
if "control_filter_to_apply" in old_req:
filter_model = old_req["control_filter_to_apply"]
model_paths[filter_model] = filter_model
if old_req.get("block_nsfw"):
model_paths["nsfw_checker"] = "nsfw_checker"
# move the model params
if model_paths["stable-diffusion"]:
model_params["stable-diffusion"] = {
"clip_skip": bool(old_req.get("clip_skip", False)),
"convert_to_tensorrt": bool(old_req.get("convert_to_tensorrt", False)),
"trt_build_config": old_req.get(
"trt_build_config", {"batch_size_range": (1, 1), "dimensions_range": [(768, 1024)]}
),
}
# move the filter params
if model_paths["realesrgan"]:
filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))}
if model_paths["latent_upscaler"]:
filter_params["latent_upscaler"] = {
"prompt": old_req["prompt"],
"negative_prompt": old_req.get("negative_prompt"),
"seed": int(old_req.get("seed", 42)),
"num_inference_steps": int(old_req.get("latent_upscaler_steps", 10)),
"guidance_scale": 0,
}
if model_paths["codeformer"]:
filter_params["codeformer"] = {
"upscale_faces": bool(old_req.get("codeformer_upscale_faces", True)),
"codeformer_fidelity": float(old_req.get("codeformer_fidelity", 0.5)),
}
# set the filters
if old_req.get("block_nsfw"):
filters.append("nsfw_checker")
if model_paths["codeformer"]:
filters.append("codeformer")
elif model_paths["gfpgan"]:
filters.append("gfpgan")
if model_paths["realesrgan"]:
filters.append("realesrgan")
elif model_paths["latent_upscaler"]:
filters.append("latent_upscaler")
return new_req
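
As a worked example of the conversion above (the values are illustrative; the expected output follows from the function as written):

old_req = {
    "prompt": "a photo of an astronaut",
    "use_stable_diffusion_model": "sd-v1-4",
    "use_face_correction": "GFPGANv1.4",
    "use_upscale": "RealESRGAN_x4plus",
    "upscale_amount": 2,
    "block_nsfw": True,
}
new_req = convert_legacy_render_req_to_new(old_req)
# new_req["model_paths"] -> {"stable-diffusion": "sd-v1-4", "gfpgan": "GFPGANv1.4",
#     "codeformer": None, "realesrgan": "RealESRGAN_x4plus", "latent_upscaler": None,
#     "nsfw_checker": "nsfw_checker", ...}
# new_req["filters"] -> ["nsfw_checker", "gfpgan", "realesrgan"]
# new_req["filter_params"] -> {"realesrgan": {"scale": 2}}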

View File

@ -7,7 +7,7 @@ from datetime import datetime
from functools import reduce
from easydiffusion import app
from easydiffusion.types import GenerateImageRequest, TaskData
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
from numpy import base_repr
from sdkit.utils import save_dicts, save_images
@ -21,6 +21,8 @@ TASK_TEXT_MAPPING = {
"seed": "Seed",
"use_stable_diffusion_model": "Stable Diffusion model",
"clip_skip": "Clip Skip",
"use_controlnet_model": "ControlNet model",
"control_filter_to_apply": "ControlNet Filter",
"use_vae_model": "VAE model",
"sampler_name": "Sampler",
"width": "Width",
@ -114,12 +116,14 @@ def format_file_name(
return filename_regex.sub("_", format)
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
def save_images_to_disk(
images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData
):
now = time.time()
app_config = app.getConfig()
folder_format = app_config.get("folder_format", "$id")
save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
metadata_entries = get_metadata_entries_for_request(req, task_data)
metadata_entries = get_metadata_entries_for_request(req, task_data, output_format)
file_number = calculate_img_number(save_dir_path, task_data)
make_filename = make_filename_callback(
app_config.get("filename_format", "$p_$tsb64"),
@ -134,9 +138,9 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
filtered_images,
save_dir_path,
file_name=make_filename,
output_format=task_data.output_format,
output_quality=task_data.output_quality,
output_lossless=task_data.output_lossless,
output_format=output_format.output_format,
output_quality=output_format.output_quality,
output_lossless=output_format.output_lossless,
)
if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(","):
@ -146,7 +150,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
save_dir_path,
file_name=make_filename,
output_format=metadata_output_format,
file_format=task_data.output_format,
file_format=output_format.output_format,
)
else:
make_filter_filename = make_filename_callback(
@ -162,17 +166,17 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
images,
save_dir_path,
file_name=make_filename,
output_format=task_data.output_format,
output_quality=task_data.output_quality,
output_lossless=task_data.output_lossless,
output_format=output_format.output_format,
output_quality=output_format.output_quality,
output_lossless=output_format.output_lossless,
)
save_images(
filtered_images,
save_dir_path,
file_name=make_filter_filename,
output_format=task_data.output_format,
output_quality=task_data.output_quality,
output_lossless=task_data.output_lossless,
output_format=output_format.output_format,
output_quality=output_format.output_quality,
output_lossless=output_format.output_lossless,
)
if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(","):
@ -181,20 +185,21 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
metadata_entries,
save_dir_path,
file_name=make_filter_filename,
output_format=task_data.metadata_output_format,
file_format=task_data.output_format,
output_format=metadata_output_format,
file_format=output_format.output_format,
)
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req, task_data)
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
metadata = get_printable_request(req, task_data, output_format)
# if text, format it in the text format expected by the UI
is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",")
if is_txt_format:
def format_value(value):
if isinstance(value, list):
return ", ".join([ str(it) for it in value ])
return ", ".join([str(it) for it in value])
return value
metadata = {
@ -208,9 +213,10 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD
return entries
def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
req_metadata = req.dict()
task_data_metadata = task_data.dict()
task_data_metadata.update(output_format.dict())
app_config = app.getConfig()
using_diffusers = app_config.get("test_diffusers", False)
@ -224,6 +230,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
metadata[key] = task_data_metadata[key]
elif key == "use_embedding_models" and using_diffusers:
embeddings_extensions = {".pt", ".bin", ".safetensors"}
def scan_directory(directory_path: str):
used_embeddings = []
for entry in os.scandir(directory_path):
@ -232,15 +239,18 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
if entry_extension not in embeddings_extensions:
continue
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])")
embedding_name_regex = regex.compile(
r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
)
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
used_embeddings.append(entry.path)
elif entry.is_dir():
used_embeddings.extend(scan_directory(entry.path))
return used_embeddings
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
# Clean up the metadata
if req.init_image is None and "prompt_strength" in metadata:
del metadata["prompt_strength"]
@ -252,9 +262,13 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
del metadata["lora_alpha"]
if task_data.use_upscale != "latent_upscaler" and "latent_upscaler_steps" in metadata:
del metadata["latent_upscaler_steps"]
if task_data.use_controlnet_model is None and "control_filter_to_apply" in metadata:
del metadata["control_filter_to_apply"]
if not using_diffusers:
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata):
for key in (
x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps", "use_controlnet_model", "control_filter_to_apply"] if x in metadata
):
del metadata[key]
return metadata

View File

@ -17,6 +17,7 @@
<link rel="stylesheet" href="/media/css/searchable-models.css">
<link rel="stylesheet" href="/media/css/image-modal.css">
<link rel="stylesheet" href="/media/css/plugins.css">
<link rel="stylesheet" href="/media/css/animations.css">
<link rel="manifest" href="/media/manifest.webmanifest">
<script src="/media/js/jquery-3.6.1.min.js"></script>
<script src="/media/js/jquery-confirm.min.js"></script>
@ -31,7 +32,7 @@
<h1>
<img id="logo_img" src="/media/images/icon-512x512.png" >
Easy Diffusion
<small><span id="version">v2.5.46</span> <span id="updateBranchLabel"></span></small>
<small><span id="version">v2.5.48</span> <span id="updateBranchLabel"></span></small>
</h1>
</div>
<div id="server-status">
@ -82,8 +83,8 @@
<label for="init_image">Initial Image (img2img) <small>(optional)</small> </label>
<div id="init_image_preview_container" class="image_preview_container">
<div id="init_image_wrapper">
<img id="init_image_preview" src="" crossorigin="anonymous" />
<div id="init_image_wrapper" class="preview_image_wrapper">
<img id="init_image_preview" class="image_preview" src="" crossorigin="anonymous" />
<span id="init_image_size_box" class="img_bottom_label"></span>
<button class="init_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
</div>
@ -108,6 +109,7 @@
</div>
<div id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Preserve color profile <small>(helps during inpainting)</small></label></div>
<div id="strict_mask_border_setting" class="pl-5"><input id="strict_mask_border" name="strict_mask_border" type="checkbox"> <label for="strict_mask_border">Strict Mask Border <small>(won't modify outside the mask, but the mask border might be visible)</small></label></div>
</div>
@ -139,12 +141,19 @@
<div><table>
<tr><b class="settings-subheader">Image Settings</b></tr>
<tr class="pl-5"><td><label for="seed">Seed:</label></td><td><input id="seed" name="seed" size="10" value="0" onkeypress="preventNonNumericalInput(event)"> <input id="random_seed" name="random_seed" type="checkbox" checked><label for="random_seed">Random</label></td></tr>
<tr class="pl-5"><td><label for="num_outputs_total">Number of Images:</label></td><td><input id="num_outputs_total" name="num_outputs_total" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label><small>(total)</small></label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label for="num_outputs_parallel"><small>(in parallel)</small></label></td></tr>
<tr class="pl-5"><td><label for="num_outputs_total">Number of Images:</label></td><td><input id="num_outputs_total" name="num_outputs_total" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label><small>(total)</small></label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label id="num_outputs_parallel_label" for="num_outputs_parallel"><small>(in parallel)</small></label></td></tr>
<tr class="pl-5"><td><label for="stable_diffusion_model">Model:</label></td><td class="model-input">
<input id="stable_diffusion_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<button id="reload-models" class="secondaryButton reloadModels"><i class='fa-solid fa-rotate'></i></button>
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
</td></tr>
<tr class="pl-5 displayNone" id="enable_trt_config">
<td><label for="convert_to_tensorrt">Enable TensorRT:</label></td>
<td class="diffusers-restart-needed">
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
<!-- <label><small>Takes upto 20 mins the first time</small></label> -->
</td>
</tr>
<tr class="pl-5 displayNone" id="clip_skip_config">
<td><label for="clip_skip">Clip Skip:</label></td>
<td class="diffusers-restart-needed">
@ -152,6 +161,60 @@
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Clip-Skip" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about Clip Skip</span></i></a>
</td>
</tr>
<tr id="controlnet_model_container" class="pl-5"><td><label for="controlnet_model">ControlNet Image:</label></td><td>
<div id="control_image_wrapper" class="preview_image_wrapper">
<img id="control_image_preview" class="image_preview" src="" crossorigin="anonymous" />
<span id="control_image_size_box" class="img_bottom_label"></span>
<button class="control_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
</div>
<input id="control_image" name="control_image" type="file" />
<a href="https://github.com/easydiffusion/easydiffusion/wiki/ControlNet" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about ControlNets</span></i></a>
<div id="controlnet_config" class="displayNone">
<label><small>Filter to apply:</small></label>
<select id="control_image_filter">
<option value="">None</option>
<optgroup label="Pose">
<option value="openpose">OpenPose (*)</option>
<option value="openpose_face">OpenPose face</option>
<option value="openpose_faceonly">OpenPose face-only</option>
<option value="openpose_hand">OpenPose hand</option>
<option value="openpose_full">OpenPose full</option>
</optgroup>
<optgroup label="Outline">
<option value="canny">Canny (*)</option>
<option value="mlsd">Straight lines</option>
<option value="scribble_hed">Scribble hed (*)</option>
<option value="scribble_hedsafe">Scribble hedsafe</option>
<option value="scribble_pidinet">Scribble pidinet</option>
<option value="scribble_pidsafe">Scribble pidsafe</option>
<option value="softedge_hed">Softedge hed</option>
<option value="softedge_hedsafe">Softedge hedsafe</option>
<option value="softedge_pidinet">Softedge pidinet</option>
<option value="softedge_pidsafe">Softedge pidsafe</option>
</optgroup>
<optgroup label="Depth">
<option value="normal_bae">Normal bae (*)</option>
<option value="depth_midas">Depth midas</option>
<option value="depth_zoe">Depth zoe</option>
<option value="depth_leres">Depth leres</option>
<option value="depth_leres++">Depth leres++</option>
</optgroup>
<optgroup label="Line art">
<option value="lineart_coarse">Lineart coarse</option>
<option value="lineart_realistic">Lineart realistic</option>
<option value="lineart_anime">Lineart anime</option>
</optgroup>
<optgroup label="Misc">
<option value="shuffle">Shuffle</option>
<option value="segment">Segment</option>
</optgroup>
</select>
<br/>
<label for="controlnet_model"><small>Model:</small></label> <input id="controlnet_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<br/>
<label><small>Will download the necessary models, the first time.</small></label>
</div>
</td></tr>
<tr class="pl-5"><td><label for="vae_model">Custom VAE:</label></td><td>
<input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<a href="https://github.com/easydiffusion/easydiffusion/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
@ -182,15 +245,15 @@
</select>
<a href="https://github.com/easydiffusion/easydiffusion/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
</td></tr>
<tr class="pl-5"><td><label>Image Size: </label></td><td>
<tr class="pl-5"><td><label>Image Size: </label></td><td id="image-size-options">
<select id="width" name="width" value="512">
<option value="128">128 (*)</option>
<option value="128">128</option>
<option value="192">192</option>
<option value="256">256 (*)</option>
<option value="256">256</option>
<option value="320">320</option>
<option value="384">384</option>
<option value="448">448</option>
<option value="512" selected>512 (*)</option>
<option value="512" selected="">512 (*)</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
@ -205,14 +268,15 @@
<option value="2048">2048</option>
</select>
<label for="width"><small>(width)</small></label>
<span id="swap-width-height" class="clickable smallButton" style="margin-left: 2px; margin-right:2px;"><i class="fa-solid fa-right-left"><span class="simple-tooltip top-left"> Swap width and height </span></i></span>
<select id="height" name="height" value="512">
<option value="128">128 (*)</option>
<option value="128">128</option>
<option value="192">192</option>
<option value="256">256 (*)</option>
<option value="256">256</option>
<option value="320">320</option>
<option value="384">384</option>
<option value="448">448</option>
<option value="512" selected>512 (*)</option>
<option value="512" selected="">512 (*)</option>
<option value="576">576</option>
<option value="640">640</option>
<option value="704">704</option>
@ -227,6 +291,22 @@
<option value="2048">2048</option>
</select>
<label for="height"><small>(height)</small></label>
<div id="recent-resolutions-container">
<span id="recent-resolutions-button" class="clickable"><i class="fa-solid fa-sliders"><span class="simple-tooltip top-left"> Advanced sizes </span></i></span>
<div id="recent-resolutions-popup" class="displayNone">
<small>Custom size:</small><br>
<input id="custom-width" name="custom-width" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)">
&times;
<input id="custom-height" name="custom-height" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)"><br>
<small>Enlarge:</small><br>
<div id="enlarge-buttons"><button id="enlarge15" class="tertiaryButton smallButton">×1.5</button>&nbsp;<button id="enlarge2" class="tertiaryButton smallButton">×2</button>&nbsp;<button id="enlarge3" class="tertiaryButton smallButton">×3</button></div>
<small>Recently&nbsp;used:</small><br>
<div id="recent-resolution-list">
</div>
</div>
</div>
<div id="small_image_warning" class="displayNone">Small image sizes can cause bad image quality</div>
</td></tr>
<tr class="pl-5"><td><label for="num_inference_steps">Inference Steps:</label></td><td> <input id="num_inference_steps" name="num_inference_steps" type="number" min="1" step="1" style="width: 42pt" value="25" onkeypress="preventNonNumericalInput(event)"></td></tr>
@ -359,6 +439,13 @@
<div class="parameters-table" id="system-settings-table"></div>
<br/>
<button id="save-system-settings-btn" class="primaryButton">Save</button>
<div id="install-extras-container" class="displayNone">
<br/>
<div id="install-extras">
<h3><i class="fa fa-cubes-stacked"></i> Optional Packages</h3>
<div class="parameters-table" id="system-settings-install-extras-table"></div>
</div>
</div>
<br/><br/>
<div id="share-easy-diffusion">
<h3><i class="fa fa-user-group"></i> Share Easy Diffusion</h3>
@ -669,10 +756,10 @@ async function init() {
events: {
statusChange: setServerStatus,
idle: onIdle,
ping: tunnelUpdate
ping: onPing
}
})
splashScreen()
// splashScreen()
// load models again, but scan for malicious this time
await getModels(true)

View File

@ -0,0 +1,68 @@
@keyframes ldio-8f673ktaleu-1 {
0% { transform: rotate(0deg) }
50% { transform: rotate(-45deg) }
100% { transform: rotate(0deg) }
}
@keyframes ldio-8f673ktaleu-2 {
0% { transform: rotate(180deg) }
50% { transform: rotate(225deg) }
100% { transform: rotate(180deg) }
}
.ldio-8f673ktaleu > div:nth-child(2) {
transform: translate(-15px,0);
}
.ldio-8f673ktaleu > div:nth-child(2) div {
position: absolute;
top: 20px;
left: 20px;
width: 60px;
height: 30px;
border-radius: 60px 60px 0 0;
background: #f3b72e;
animation: ldio-8f673ktaleu-1 1s linear infinite;
transform-origin: 30px 30px
}
.ldio-8f673ktaleu > div:nth-child(2) div:nth-child(2) {
animation: ldio-8f673ktaleu-2 1s linear infinite
}
.ldio-8f673ktaleu > div:nth-child(2) div:nth-child(3) {
transform: rotate(-90deg);
animation: none;
}@keyframes ldio-8f673ktaleu-3 {
0% { transform: translate(95px,0); opacity: 0 }
20% { opacity: 1 }
100% { transform: translate(35px,0); opacity: 1 }
}
.ldio-8f673ktaleu > div:nth-child(1) {
display: block;
}
.ldio-8f673ktaleu > div:nth-child(1) div {
position: absolute;
top: 46px;
left: -4px;
width: 8px;
height: 8px;
border-radius: 50%;
background: #3869c5;
animation: ldio-8f673ktaleu-3 1s linear infinite
}
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(1) { animation-delay: -0.67s }
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(2) { animation-delay: -0.33s }
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(3) { animation-delay: 0s }
.loadingio-spinner-bean-eater-x0y3u8qky4n {
width: 58px;
height: 58px;
display: inline-block;
overflow: hidden;
background: none;
}
.ldio-8f673ktaleu {
width: 100%;
height: 100%;
position: relative;
transform: translateZ(0) scale(0.58);
backface-visibility: hidden;
transform-origin: 0 0;
}
.ldio-8f673ktaleu div { box-sizing: content-box; }
/* generated by https://loading.io/ */

View File

@ -794,7 +794,7 @@ div.img-preview img {
margin-bottom: 8px;
}
#init_image_preview_container:not(.has-image) #init_image_wrapper,
#init_image_preview_container:not(.has-image) .preview_image_wrapper,
#init_image_preview_container:not(.has-image) #inpaint_button_container {
display: none;
}
@ -831,14 +831,14 @@ div.img-preview img {
gap: 8px;
}
#init_image_wrapper {
.preview_image_wrapper {
grid-row: span 3;
position: relative;
width: fit-content;
max-height: 150px;
}
#init_image_preview {
.image_preview {
max-height: 150px;
height: 100%;
width: 100%;
@ -1753,3 +1753,82 @@ body.wait-pause {
content: "Please restart Easy Diffusion!";
font-size: 10pt;
}
input#custom-width, input#custom-height {
width: 47pt;
}
div#recent-resolutions-container {
position: relative;
display:inline-block;
}
div#recent-resolutions-popup {
position: absolute;
right: 0px;
margin: 3px;
padding: 0.2em 1em 0.4em 1em;
z-index: 1;
background: var(--background-color3);
border-radius: 4px;
box-shadow: 0 20px 28px 0 rgba(0, 0, 0, 0.15), 0 6px 20px 0 rgba(0, 0, 0, 0.15);
}
div#recent-resolutions-popup small {
opacity: 0.7;
}
td#image-size-options small {
margin-right: 0px !important;
}
td#image-size-options {
white-space: nowrap;
}
div#recent-resolution-list {
text-align: center;
}
div#enlarge-buttons {
text-align: center;
}
.clickable {
cursor: pointer;
}
.imgContainer .spinner {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
margin: 0;
padding: 0;
background: var(--background-color3);
opacity: 0.95;
border-radius: 5px;
padding: 4pt;
border: 1px solid var(--button-color);
box-shadow: 0px 0px 4px black;
}
.imgContainer .spinnerStatus {
font-size: 10pt;
}
#controlnet_model_container small {
color: var(--text-color)
}
#control_image {
width: 130pt;
}
#controlnet_model {
width: 77%;
}
/* hack for fixing Image Modifier Improvements plugin */
#imageTagPopupContainer {
position: absolute;
}

View File

@ -1047,7 +1047,9 @@
}
}
class FilterTask extends Task {
constructor(options = {}) {}
constructor(options = {}) {
super(options)
}
/** Send current task to server.
* @param {*} [timeout=-1] Optional timeout value in ms
* @returns the response from the render request.
@ -1055,9 +1057,27 @@
*/
async post(timeout = -1) {
let jsonResponse = await super.post("/filter", timeout)
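// If the response has no task ID, give plugins a chance to recover via
// EVENT_UNEXPECTED_RESPONSE (a handler may supply a replacement response
// through `continueWith`) before aborting.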
if (typeof jsonResponse?.task !== "number") {
console.warn("Endpoint error response: ", jsonResponse)
const event = Object.assign({ task: this }, jsonResponse)
await eventSource.fireEvent(EVENT_UNEXPECTED_RESPONSE, event)
if ("continueWith" in event) {
jsonResponse = await Promise.resolve(event.continueWith)
}
if (typeof jsonResponse?.task !== "number") {
const err = new Error(jsonResponse?.detail || "Endpoint response does not contain a task ID.")
this.abort(err)
throw err
}
}
this._setId(jsonResponse.task)
if (jsonResponse.stream) {
this.streamUrl = jsonResponse.stream
}
this._setStatus(TaskStatus.waiting)
return jsonResponse
}
checkReqBody() {}
enqueue(progressCallback) {
return Task.enqueueNew(this, FilterTask, progressCallback)
}
@ -1068,6 +1088,65 @@
if (this.isStopped) {
return
}
this._setStatus(TaskStatus.pending)
progressCallback?.call(this, { reqBody: this._reqBody })
Object.freeze(this._reqBody)
// Post task request to backend
let renderRes = undefined
try {
renderRes = yield this.post()
yield progressCallback?.call(this, { renderResponse: renderRes })
} catch (e) {
yield progressCallback?.call(this, { detail: e.message })
throw e
}
try {
// Wait for task to start on server.
yield this.waitUntil({
callback: function() {
return progressCallback?.call(this, {})
},
status: TaskStatus.processing,
})
} catch (e) {
this.abort(e)
throw e
}
// Task started!
// Open the reader.
const reader = this.reader
const task = this
reader.onError = function(response) {
if (progressCallback) {
task.abort(new Error(response.statusText))
return progressCallback.call(task, { response, reader })
}
return Task.prototype.onError.call(task, response)
}
yield progressCallback?.call(this, { reader })
//Start streaming the results.
const streamGenerator = reader.open()
let value = undefined
let done = undefined
yield progressCallback?.call(this, { stream: streamGenerator })
do {
;({ value, done } = yield streamGenerator.next())
if (typeof value !== "object") {
continue
}
if (value.status !== undefined) {
yield progressCallback?.call(this, value)
if (value.status === "succeeded" || value.status === "failed") {
done = true
}
}
} while (!done)
return value
}
static start(task, progressCallback) {
if (typeof task !== "object") {

View File

@ -626,6 +626,7 @@ class ImageEditor {
.getImageData(0, 0, this.width, this.height)
.data.some((channel) => channel !== 0)
maskSetting.checked = !is_blank
maskSetting.dispatchEvent(new Event("change"))
}
this.hide()
}

View File

@ -5,6 +5,9 @@ const MIN_GPUS_TO_SHOW_SELECTION = 2
const IMAGE_REGEX = new RegExp("data:image/[A-Za-z]+;base64")
const htmlTaskMap = new WeakMap()
const spinnerPacmanHtml =
'<div class="loadingio-spinner-bean-eater-x0y3u8qky4n"><div class="ldio-8f673ktaleu"><div><div></div><div></div><div></div></div><div><div></div><div></div><div></div></div></div></div>'
const taskConfigSetup = {
taskConfig: {
seed: { value: ({ seed }) => seed, label: "Seed" },
@ -46,6 +49,7 @@ const taskConfigSetup = {
use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
preserve_init_image_color_profile: "Preserve Color Profile",
strict_mask_border: "Strict Mask Border",
},
pluginTaskConfig: {},
getCSSKey: (key) =>
@ -74,14 +78,30 @@ let randomSeedField = document.querySelector("#random_seed")
let seedField = document.querySelector("#seed")
let widthField = document.querySelector("#width")
let heightField = document.querySelector("#height")
let customWidthField = document.querySelector("#custom-width")
let customHeightField = document.querySelector("#custom-height")
let recentResolutionsButton = document.querySelector("#recent-resolutions-button")
let recentResolutionsPopup = document.querySelector("#recent-resolutions-popup")
let recentResolutionList = document.querySelector("#recent-resolution-list")
let enlarge15Button = document.querySelector("#enlarge15")
let enlarge2Button = document.querySelector("#enlarge2")
let enlarge3Button = document.querySelector("#enlarge3")
let swapWidthHeightButton = document.querySelector("#swap-width-height")
let smallImageWarning = document.querySelector("#small_image_warning")
let initImageSelector = document.querySelector("#init_image")
let initImagePreview = document.querySelector("#init_image_preview")
let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview")
let controlImageSelector = document.querySelector("#control_image")
let controlImagePreview = document.querySelector("#control_image_preview")
let controlImageClearBtn = document.querySelector(".control_image_clear")
let controlImageContainer = document.querySelector("#control_image_wrapper")
let controlImageFilterField = document.querySelector("#control_image_filter")
let applyColorCorrectionField = document.querySelector("#apply_color_correction")
let strictMaskBorderField = document.querySelector("#strict_mask_border")
let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting")
let strictMaskBorderSetting = document.querySelector("#strict_mask_border_setting")
let promptStrengthSlider = document.querySelector("#prompt_strength_slider")
let promptStrengthField = document.querySelector("#prompt_strength")
let samplerField = document.querySelector("#sampler_name")
@ -99,6 +119,7 @@ let codeformerFidelityField = document.querySelector("#codeformer_fidelity")
let stableDiffusionModelField = new ModelDropdown(document.querySelector("#stable_diffusion_model"), "stable-diffusion")
let clipSkipField = document.querySelector("#clip_skip")
let tilingField = document.querySelector("#tiling")
let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None", false)
let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None")
let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None")
let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider")
@ -160,6 +181,7 @@ let imagePreviewContent = document.querySelector("#preview-content")
let undoButton = document.querySelector("#undo")
let undoBuffer = []
const UNDO_LIMIT = 20
const MAX_IMG_UNDO_ENTRIES = 5
let loraModels = []
@ -271,24 +293,24 @@ function setServerStatus(event) {
// e : MouseEvent
// prompt : Text to be shown as prompt. Should be a question to which "yes" is a good answer.
// fn : function to be called if the user confirms the dialog or has the shift key pressed
// allowSkip: Allow skipping the dialog using the shift key or the confirm_dangerous_actions setting (default: true)
//
// If the user had the shift key pressed while clicking, the function fn will be executed.
// If the setting "confirm_dangerous_actions" in the system settings is disabled, the function
// fn will be executed.
// Otherwise, a confirmation dialog is shown. If the user confirms, the function fn will also
// be executed.
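// Example: shiftOrConfirm(e, "Are you sure you want to uninstall TensorRT?", fn, false)
// always shows the dialog, since allowSkip is false.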
function shiftOrConfirm(e, prompt, fn) {
function shiftOrConfirm(e, prompt, fn, allowSkip = true) {
e.stopPropagation()
if (e.shiftKey || !confirmDangerousActionsField.checked) {
let tip = allowSkip
? '<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>'
: ""
if (allowSkip && (e.shiftKey || !confirmDangerousActionsField.checked)) {
fn(e)
} else {
confirm(
'<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>',
prompt,
() => {
fn(e)
}
)
confirm(tip, prompt, () => {
fn(e)
})
}
}
@ -412,6 +434,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
</div>
<button class="imgPreviewItemClearBtn image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
<span class="img_bottom_label"></span>
<div class="spinner displayNone"><center>${spinnerPacmanHtml}</center><div class="spinnerStatus"></div></div>
</div>
`
outputContainer.appendChild(imageItemElem)
@ -488,6 +511,8 @@ function showImages(reqBody, res, outputContainer, livePreview) {
const imageSeedLabel = imageItemElem.querySelector(".imgSeedLabel")
imageSeedLabel.innerText = "Seed: " + req.seed
const imageUndoBuffer = []
const imageRedoBuffer = []
let buttons = [
{ text: "Use as Input", on_click: onUseAsInputClick },
[
@ -505,8 +530,10 @@ function showImages(reqBody, res, outputContainer, livePreview) {
{ text: "Make Similar Images", on_click: onMakeSimilarClick },
{ text: "Draw another 25 steps", on_click: onContinueDrawingClick },
[
{ text: "Upscale", on_click: onUpscaleClick, filter: (req, img) => !req.use_upscale },
{ text: "Fix Faces", on_click: onFixFacesClick, filter: (req, img) => !req.use_face_correction },
{ html: '<i class="fa-solid fa-undo"></i> Undo', on_click: onUndoFilter },
{ html: '<i class="fa-solid fa-redo"></i> Redo', on_click: onRedoFilter },
{ text: "Upscale", on_click: onUpscaleClick },
{ text: "Fix Faces", on_click: onFixFacesClick },
],
]
@ -515,6 +542,14 @@ function showImages(reqBody, res, outputContainer, livePreview) {
const imgItemInfo = imageItemElem.querySelector(".imgItemInfo")
const img = imageItemElem.querySelector("img")
const spinner = imageItemElem.querySelector(".spinner")
const spinnerStatus = imageItemElem.querySelector(".spinnerStatus")
const tools = {
spinner: spinner,
spinnerStatus: spinnerStatus,
undoBuffer: imageUndoBuffer,
redoBuffer: imageRedoBuffer,
}
const createButton = function(btnInfo) {
if (Array.isArray(btnInfo)) {
const wrapper = document.createElement("div")
@ -540,8 +575,16 @@ function showImages(reqBody, res, outputContainer, livePreview) {
if (btnInfo.on_click || !isLabel) {
newButton.addEventListener("click", function(event) {
btnInfo.on_click(req, img, event)
btnInfo.on_click.bind(newButton)(req, img, event, tools)
})
if (btnInfo.on_click === onUndoFilter) {
tools["undoButton"] = newButton
newButton.classList.add("displayNone")
}
if (btnInfo.on_click === onRedoFilter) {
tools["redoButton"] = newButton
newButton.classList.add("displayNone")
}
}
if (btnInfo.class !== undefined) {
@ -656,16 +699,86 @@ function enqueueImageVariationTask(req, img, reqDiff) {
createTask(newTaskRequest)
}
function onUpscaleClick(req, img) {
enqueueImageVariationTask(req, img, {
use_upscale: upscaleModelField.value,
function applyInlineFilter(filterName, path, filterParams, img, statusText, tools) {
const filterReq = {
image: img.src,
filter: filterName,
model_paths: {},
filter_params: filterParams,
output_format: outputFormatField.value,
output_quality: parseInt(outputQualityField.value),
output_lossless: outputLosslessField.checked,
}
filterReq.model_paths[filterName] = path
tools.spinnerStatus.innerText = statusText
tools.spinner.classList.remove("displayNone")
SD.filter(filterReq, (e) => {
if (e.status === "succeeded") {
let prevImg = img.src
img.src = e.output[0]
tools.spinner.classList.add("displayNone")
if (prevImg.length > 0) {
tools.undoBuffer.push(prevImg)
tools.redoBuffer = []
if (tools.undoBuffer.length > MAX_IMG_UNDO_ENTRIES) {
let n = tools.undoBuffer.length
tools.undoBuffer.splice(0, n - MAX_IMG_UNDO_ENTRIES)
}
tools.undoButton.classList.remove("displayNone")
tools.redoButton.classList.add("displayNone")
}
} else if (e.status == "failed") {
alert("Error running upscale: " + e.detail)
tools.spinner.classList.add("displayNone")
}
})
}
function onFixFacesClick(req, img) {
enqueueImageVariationTask(req, img, {
use_face_correction: gfpganModelField.value,
})
function moveImageBetweenBuffers(img, fromBuffer, toBuffer, fromButton, toButton) {
if (fromBuffer.length === 0) {
return
}
let src = fromBuffer.pop()
if (src.length > 0) {
toBuffer.push(img.src)
img.src = src
}
if (fromBuffer.length === 0) {
fromButton.classList.add("displayNone")
}
if (toBuffer.length > 0) {
toButton.classList.remove("displayNone")
}
}
function onUndoFilter(req, img, e, tools) {
moveImageBetweenBuffers(img, tools.undoBuffer, tools.redoBuffer, tools.undoButton, tools.redoButton)
}
function onRedoFilter(req, img, e, tools) {
moveImageBetweenBuffers(img, tools.redoBuffer, tools.undoBuffer, tools.redoButton, tools.undoButton)
}
function onUpscaleClick(req, img, e, tools) {
let path = upscaleModelField.value
let scale = parseInt(upscaleAmountField.value)
let filterName = path.toLowerCase().includes("realesrgan") ? "realesrgan" : "latent_upscaler"
let statusText = "Upscaling by " + scale + "x using " + filterName
applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools)
}
function onFixFacesClick(req, img, e, tools) {
let path = gfpganModelField.value
let filterName = path.toLowerCase().includes("gfpgan") ? "gfpgan" : "codeformer"
let statusText = "Fixing faces with " + filterName
applyInlineFilter(filterName, path, {}, img, statusText, tools)
}
function onContinueDrawingClick(req, img) {
@ -909,7 +1022,9 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
<a href="https://www.ibm.com/docs/en/opw/8.2.0?topic=tuning-optional-increasing-paging-file-size-windows-computers" target="_blank">Windows</a> or
<a href="https://linuxhint.com/increase-swap-space-linux/" target="_blank">Linux</a>.<br/>
3. Try restarting your computer.<br/>`
} else if (msg.includes("RuntimeError: output with shape [320, 320] doesn't match the broadcast shape")) {
} else if (
msg.includes("RuntimeError: output with shape [320, 320] doesn't match the broadcast shape")
) {
msg += `<br/><br/>
<b>Reason</b>: You tried to use a LORA that was trained for a different Stable Diffusion model version!
<br/><br/>
@ -1237,9 +1352,25 @@ function createTask(task) {
function getCurrentUserRequest() {
const numOutputsTotal = parseInt(numOutputsTotalField.value)
const numOutputsParallel = parseInt(numOutputsParallelField.value)
let numOutputsParallel = parseInt(numOutputsParallelField.value)
const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value)
// if (
// testDiffusers.checked &&
// document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall" &&
// document.querySelector("#convert_to_tensorrt").checked
// ) {
// // TRT enabled
// numOutputsParallel = 1 // force 1 parallel
// }
// clamp to multiple of 8
let width = parseInt(widthField.value)
let height = parseInt(heightField.value)
width = width - (width % 8)
height = height - (height % 8)
const newTask = {
batchesDone: 0,
numOutputsTotal: numOutputsTotal,
@ -1252,8 +1383,8 @@ function getCurrentUserRequest() {
num_outputs: numOutputsParallel,
num_inference_steps: parseInt(numInferenceStepsField.value),
guidance_scale: parseFloat(guidanceScaleField.value),
width: parseInt(widthField.value),
height: parseInt(heightField.value),
width: width,
height: height,
// allow_nsfw: allowNSFWField.checked,
vram_usage_level: vramUsageLevelField.value,
sampler_name: samplerField.value,
@ -1283,6 +1414,7 @@ function getCurrentUserRequest() {
// }
if (maskSetting.checked) {
newTask.reqBody.mask = imageInpainter.getImg()
newTask.reqBody.strict_mask_border = strictMaskBorderField.checked
}
newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked
if (!testDiffusers.checked) {
@ -1323,6 +1455,34 @@ function getCurrentUserRequest() {
newTask.reqBody.lora_alpha = modelStrengths
}
}
if (testDiffusers.checked && document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
// TRT is installed
newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked
let trtBuildConfig = {
batch_size_range: [
parseInt(document.querySelector("#trt-build-min-batch").value),
parseInt(document.querySelector("#trt-build-max-batch").value),
],
dimensions_range: [],
}
let sizes = [512, 768, 1024, 1280, 1536]
sizes.forEach((i) => {
let el = document.querySelector("#trt-build-res-" + i)
if (el.checked) {
trtBuildConfig["dimensions_range"].push([i, i + 256])
}
})
newTask.reqBody.trt_build_config = trtBuildConfig
}
if (controlnetModelField.value !== "" && IMAGE_REGEX.test(controlImagePreview.src)) {
newTask.reqBody.use_controlnet_model = controlnetModelField.value
newTask.reqBody.control_image = controlImagePreview.src
if (controlImageFilterField.value !== "") {
newTask.reqBody.control_filter_to_apply = controlImageFilterField.value
}
}
return newTask
}
@ -1728,6 +1888,51 @@ function onFixFaceModelChange() {
gfpganModelField.addEventListener("change", onFixFaceModelChange)
onFixFaceModelChange()
function onControlnetModelChange() {
let configBox = document.querySelector("#controlnet_config")
if (IMAGE_REGEX.test(controlImagePreview.src)) {
configBox.classList.remove("displayNone")
controlImageContainer.classList.remove("displayNone")
} else {
configBox.classList.add("displayNone")
controlImageContainer.classList.add("displayNone")
}
}
controlImagePreview.addEventListener("load", onControlnetModelChange)
controlImagePreview.addEventListener("unload", onControlnetModelChange)
onControlnetModelChange()
function onControlImageFilterChange() {
let filterId = controlImageFilterField.value
if (filterId.includes("openpose")) {
controlnetModelField.value = "control_v11p_sd15_openpose"
} else if (filterId === "canny") {
controlnetModelField.value = "control_v11p_sd15_canny"
} else if (filterId === "mlsd") {
controlnetModelField.value = "control_v11p_sd15_mlsd"
} else if (filterId === "mlsd") {
controlnetModelField.value = "control_v11p_sd15_mlsd"
} else if (filterId.includes("scribble")) {
controlnetModelField.value = "control_v11p_sd15_scribble"
} else if (filterId.includes("softedge")) {
controlnetModelField.value = "control_v11p_sd15_softedge"
} else if (filterId === "normal_bae") {
controlnetModelField.value = "control_v11p_sd15_normalbae"
} else if (filterId.includes("depth")) {
controlnetModelField.value = "control_v11f1p_sd15_depth"
} else if (filterId === "lineart_anime") {
controlnetModelField.value = "control_v11p_sd15s2_lineart_anime"
} else if (filterId.includes("lineart")) {
controlnetModelField.value = "control_v11p_sd15_lineart"
} else if (filterId === "shuffle") {
controlnetModelField.value = "control_v11e_sd15_shuffle"
} else if (filterId === "segment") {
controlnetModelField.value = "control_v11p_sd15_seg"
}
}
controlImageFilterField.addEventListener("change", onControlImageFilterChange)
onControlImageFilterChange()
upscaleModelField.disabled = !useUpscalingField.checked
upscaleAmountField.disabled = !useUpscalingField.checked
useUpscalingField.addEventListener("change", function(e) {
@ -1957,6 +2162,7 @@ function checkRandomSeed() {
randomSeedField.addEventListener("input", checkRandomSeed)
checkRandomSeed()
// warning: the core plugin `image-editor-improvements.js:172` replaces loadImg2ImgFromFile() with a custom version
function loadImg2ImgFromFile() {
if (initImageSelector.files.length === 0) {
return
@ -1983,6 +2189,7 @@ function img2imgLoad() {
}
initImagePreviewContainer.classList.add("has-image")
colorCorrectionSetting.style.display = ""
strictMaskBorderSetting.style.display = maskSetting.checked ? "" : "none"
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@ -2000,6 +2207,7 @@ function img2imgUnload() {
}
initImagePreviewContainer.classList.remove("has-image")
colorCorrectionSetting.style.display = "none"
strictMaskBorderSetting.style.display = "none"
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
}
initImagePreview.addEventListener("load", img2imgLoad)
@ -2008,11 +2216,55 @@ initImageClearBtn.addEventListener("click", img2imgUnload)
maskSetting.addEventListener("click", function() {
onDimensionChange()
})
maskSetting.addEventListener("change", function() {
strictMaskBorderSetting.style.display = this.checked ? "" : "none"
})
promptsFromFileBtn.addEventListener("click", function() {
promptsFromFileSelector.click()
})
function loadControlnetImageFromFile() {
if (controlImageSelector.files.length === 0) {
return
}
let reader = new FileReader()
let file = controlImageSelector.files[0]
reader.addEventListener("load", function(event) {
controlImagePreview.src = reader.result
})
if (file) {
reader.readAsDataURL(file)
}
}
controlImageSelector.addEventListener("change", loadControlnetImageFromFile)
function controlImageLoad() {
let w = controlImagePreview.naturalWidth
let h = controlImagePreview.naturalHeight
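// snap the control image dimensions down to multiples of 8, as the model requires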
w = w - (w % 8)
h = h - (h % 8)
addImageSizeOption(w)
addImageSizeOption(h)
widthField.value = w
heightField.value = h
widthField.dispatchEvent(new Event("change"))
heightField.dispatchEvent(new Event("change"))
}
controlImagePreview.addEventListener("load", controlImageLoad)
function controlImageUnload() {
controlImageSelector.value = null
controlImagePreview.src = ""
controlImagePreview.dispatchEvent(new Event("unload"))
}
controlImageClearBtn.addEventListener("click", controlImageUnload)
promptsFromFileSelector.addEventListener("change", async function() {
if (promptsFromFileSelector.files.length === 0) {
return
@ -2123,6 +2375,11 @@ resumeBtn.addEventListener("click", function() {
document.body.classList.remove("wait-pause")
})
function onPing(event) {
tunnelUpdate(event)
packagesUpdate(event)
}
function tunnelUpdate(event) {
if ("cloudflare" in event) {
document.getElementById("cloudflare-off").classList.add("displayNone")
@ -2136,6 +2393,42 @@ function tunnelUpdate(event) {
}
}
let trtSettingsForced = false
function packagesUpdate(event) {
let trtBtn = document.getElementById("toggle-tensorrt-install")
let trtInstalled = "packages_installed" in event && "tensorrt" in event["packages_installed"]
if ("packages_installing" in event && event["packages_installing"].includes("tensorrt")) {
trtBtn.innerHTML = "Installing.."
trtBtn.disabled = true
} else {
trtBtn.innerHTML = trtInstalled ? "Uninstall" : "Install"
trtBtn.disabled = false
}
if (document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
document.querySelector("#enable_trt_config").classList.remove("displayNone")
document.querySelector("#trt-build-config").classList.remove("displayNone")
if (!trtSettingsForced) {
// settings for demo
promptField.value = "Dragons fighting with a knight, castle, war scene, fantasy, cartoon, flames, HD"
seedField.value = 3187947173
widthField.value = 1024
heightField.value = 768
randomSeedField.checked = false
seedField.disabled = false
stableDiffusionModelField.value = "sd-v1-4"
// numOutputsParallelField.classList.add("displayNone")
// document.querySelector("#num_outputs_parallel_label").classList.add("displayNone")
trtSettingsForced = true
}
}
}
document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", async function() {
let command = "stop"
if (document.getElementById("toggle-cloudflare-tunnel").innerHTML == "Start") {
@ -2155,6 +2448,63 @@ document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", as
console.log(`Cloudflare tunnel ${command} result:`, res)
})
document.getElementById("toggle-tensorrt-install").addEventListener("click", function(e) {
if (this.disabled === true) {
return
}
let command = this.innerHTML.toLowerCase()
let self = this
shiftOrConfirm(
e,
"Are you sure you want to " + command + " TensorRT?",
async function() {
showToast(`TensorRT ${command} started. Please wait.`)
self.disabled = true
if (command === "install") {
self.innerHTML = "Installing.."
} else if (command === "uninstall") {
self.innerHTML = "Uninstalling.."
}
if (command === "installing..") {
alert("Already installing TensorRT!")
return
}
if (command !== "install" && command !== "uninstall") {
return
}
let res = await fetch("/package/tensorrt", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
command: command,
}),
})
res = await res.json()
self.disabled = false
if (res.status === "OK") {
alert("TensorRT " + command + "ed successfully!")
self.innerHTML = command === "install" ? "Uninstall" : "Install"
} else if (res.status_code === 500) {
alert("TensorselfRT failed to " + command + ": " + res.detail)
self.innerHTML = command === "install" ? "Install" : "Uninstall"
}
console.log(`Package ${command} result:`, res)
},
false
)
})
/* Embeddings */
function updateEmbeddingsList(filter = "") {
@ -2171,7 +2521,10 @@ function updateEmbeddingsList(filter = "") {
} else {
let subdir = html(m[1], prefix + m[0] + "/", filter)
if (subdir != "") {
folders += `<div class="embedding-category"><h4 class="collapsible">${prefix}${m[0]}</h4><div class="collapsible-content">` + subdir + '</div></div>'
folders +=
`<div class="embedding-category"><h4 class="collapsible">${prefix}${m[0]}</h4><div class="collapsible-content">` +
subdir +
"</div></div>"
}
}
})
@ -2293,7 +2646,6 @@ embeddingsCollapsiblesBtn.addEventListener("click", (e) => {
}
})
if (testDiffusers.checked) {
document.getElementById("embeddings-container").classList.remove("displayNone")
}
@ -2390,3 +2742,172 @@ createLoraEntries()
// }
// document.querySelectorAll("input[type=number]").forEach(showSpinnerOnlyOnHover)
////////////////////////////// Image Size Widget //////////////////////////////////////////
function roundToMultiple(number, n) {
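// n comes from an <input>'s `step` attribute and may be an empty string when unset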
if (n == "") {
n = 1
}
return Math.round(number / n) * n
}
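// Add `size` to the width/height dropdowns (kept sorted) if it isn't listed yet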
function addImageSizeOption(size) {
let sizes = Object.values(widthField.options).map((o) => o.value)
if (!sizes.includes(String(size))) {
sizes.push(String(size))
sizes.sort((a, b) => Number(a) - Number(b))
let option = document.createElement("option")
option.value = size
option.text = `${size}`
widthField.add(option, sizes.indexOf(String(size)))
heightField.add(option.cloneNode(true), sizes.indexOf(String(size)))
}
}
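// Snap w and h to the current step, make sure both values exist in the dropdowns,
// then apply them and fire "change" events so dependent UI can react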
function setImageWidthHeight(w, h) {
let step = customWidthField.step
w = roundToMultiple(w, step)
h = roundToMultiple(h, step)
addImageSizeOption(w)
addImageSizeOption(h)
widthField.value = w
heightField.value = h
widthField.dispatchEvent(new Event("change"))
heightField.dispatchEvent(new Event("change"))
}
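// Scale the current resolution by `factor`; the result lands in the custom fields
// and is applied when the resolutions popup closes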
function enlargeImageSize(factor) {
let step = customWidthField.step
let w = roundToMultiple(widthField.value * factor, step)
let h = roundToMultiple(heightField.value * factor, step)
customWidthField.value = w
customHeightField.value = h
}
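// Recently used resolutions, persisted in localStorage and rendered as
// quick-select buttons inside the resolutions popup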
let recentResolutionsValues = []
;(function() {
///// Init resolutions dropdown
function makeResolutionButtons() {
recentResolutionList.innerHTML = ""
recentResolutionsValues.forEach((el) => {
let button = document.createElement("button")
button.classList.add("tertiaryButton")
button.style.width = "8em"
button.innerHTML = `${el.w}&times;${el.h}`
button.addEventListener("click", () => {
customWidthField.value = el.w
customHeightField.value = el.h
hidePopup()
})
recentResolutionList.appendChild(button)
recentResolutionList.appendChild(document.createElement("br"))
})
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
}
enlarge15Button.addEventListener("click", () => {
enlargeImageSize(1.5)
hidePopup()
})
enlarge2Button.addEventListener("click", () => {
enlargeImageSize(2)
hidePopup()
})
enlarge3Button.addEventListener("click", () => {
enlargeImageSize(3)
hidePopup()
})
customWidthField.addEventListener("change", () => {
let w = customWidthField.value
customWidthField.value = roundToMultiple(w, customWidthField.step)
if (w != customWidthField.value) {
showToast(`Rounded width to the closest multiple of ${customWidthField.step}.`)
}
})
customHeightField.addEventListener("change", () => {
let h = customHeightField.value
customHeightField.value = roundToMultiple(h, customHeightField.step)
if (h != customHeightField.value) {
showToast(`Rounded height to the closest multiple of ${customHeightField.step}.`)
}
})
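// On "Make Image", move the current resolution to the front of the recents list (8 entries max)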
makeImageBtn.addEventListener("click", () => {
let w = widthField.value
let h = heightField.value
recentResolutionsValues = recentResolutionsValues.filter((el) => el.w != w || el.h != h)
recentResolutionsValues.unshift({ w: w, h: h })
recentResolutionsValues = recentResolutionsValues.slice(0, 8)
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
makeResolutionButtons()
})
let _jsonstring = localStorage.recentResolutionsValues
if (_jsonstring == undefined) {
recentResolutionsValues = [
{ w: 512, h: 512 },
{ w: 640, h: 448 },
{ w: 448, h: 640 },
{ w: 512, h: 768 },
{ w: 768, h: 512 },
{ w: 1024, h: 768 },
{ w: 768, h: 1024 },
]
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
} else {
recentResolutionsValues = JSON.parse(localStorage.recentResolutionsValues)
}
makeResolutionButtons()
recentResolutionsValues.forEach((val) => {
addImageSizeOption(val.w)
addImageSizeOption(val.h)
})
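// Hide the popup when the user clicks anywhere outside of it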
function processClick(e) {
if (!recentResolutionsPopup.contains(e.target)) {
hidePopup()
}
}
function showPopup() {
customWidthField.value = widthField.value
customHeightField.value = heightField.value
recentResolutionsPopup.classList.remove("displayNone")
document.addEventListener("click", processClick)
}
function hidePopup() {
recentResolutionsPopup.classList.add("displayNone")
setImageWidthHeight(customWidthField.value, customHeightField.value)
document.removeEventListener("click", processClick)
}
recentResolutionsButton.addEventListener("click", (event) => {
if (recentResolutionsPopup.classList.contains("displayNone")) {
showPopup()
event.stopPropagation()
} else {
hidePopup()
}
})
swapWidthHeightButton.addEventListener("click", (event) => {
let temp = widthField.value
widthField.value = heightField.value
heightField.value = temp
})
})()

View File

@@ -16,6 +16,7 @@ var ParameterType = {
*/
let parametersTable = document.querySelector("#system-settings-table")
let networkParametersTable = document.querySelector("#system-settings-network-table")
let installExtrasTable = document.querySelector("#system-settings-install-extras-table")
/**
* JSDoc style
@@ -241,6 +242,29 @@ var PARAMETERS = [
render: () => '<button id="toggle-cloudflare-tunnel" class="primaryButton">Start</button>',
table: networkParametersTable,
},
{
id: "nvidia_tensorrt",
type: ParameterType.custom,
label: "NVIDIA TensorRT",
note: `Faster image generation by converting your Stable Diffusion models to the NVIDIA TensorRT format. You can choose the
models to convert. Download size: approximately 2 GB.<br/><br/>
<b>Early access version:</b> support for LoRA is still under development.
<div id="trt-build-config" class="displayNone">
<h3>Build Config:</h3>
Batch size range:
<label>Min:</label> <input id="trt-build-min-batch" type="number" min="1" value="1" style="width: 40pt" />
<label>Max:</label> <input id="trt-build-max-batch" type="number" min="1" value="1" style="width: 40pt" /><br/><br/>
<b>Build for resolutions</b>:<br/>
<input id="trt-build-res-512" type="checkbox" value="1" /> 512x512 to 768x768<br/>
<input id="trt-build-res-768" type="checkbox" value="1" checked /> 768x768 to 1024x1024<br/>
<input id="trt-build-res-1024" type="checkbox" value="1" /> 1024x1024 to 1280x1280<br/>
<input id="trt-build-res-1280" type="checkbox" value="1" /> 1280x1280 to 1536x1536<br/>
<input id="trt-build-res-1536" type="checkbox" value="1" /> 1536x1536 to 1792x1792<br/>
</div>`,
icon: "fa-angles-up",
render: () => '<button id="toggle-tensorrt-install" class="primaryButton">Install</button>',
table: installExtrasTable,
},
]
function getParameterSettingsEntry(id) {
@@ -441,6 +465,8 @@ async function getAppConfig() {
document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => {
option.style.display = "none"
})
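// the classic (non-diffusers) pipeline needs image sizes in multiples of 64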
customWidthField.step = 64
customHeightField.step = 64
} else {
document.querySelector("#lora_model_container").style.display = ""
document.querySelector("#tiling_container").style.display = ""
@@ -451,6 +477,8 @@ async function getAppConfig() {
document.querySelector("#clip_skip_config").classList.remove("displayNone")
document.querySelector("#embeddings-button").classList.remove("displayNone")
document.querySelector("#negative-embeddings-button").classList.remove("displayNone")
customWidthField.step = 8
customHeightField.step = 8
}
console.log("get config status response", config)
@@ -582,6 +610,23 @@ function setDeviceInfo(devices) {
systemInfoEl.querySelector("#system-info-cpu").innerText = cpu
systemInfoEl.querySelector("#system-info-gpus-all").innerHTML = allGPUs.join("</br>")
systemInfoEl.querySelector("#system-info-rendering-devices").innerHTML = activeGPUs.join("</br>")
// TensorRT: reveal the "install extras" section only when diffusers is enabled,
// the server reports TRT support, and an active NVIDIA GPU is present
if (devices.active && testDiffusers.checked && devices.enable_trt === true) {
let nvidiaGPUs = Object.keys(devices.active).filter((d) => {
let gpuName = devices.active[d].name
gpuName = gpuName.toLowerCase()
return (
gpuName.includes("nvidia") ||
gpuName.includes("geforce") ||
gpuName.includes("quadro") ||
gpuName.includes("tesla")
)
})
if (nvidiaGPUs.length > 0) {
document.querySelector("#install-extras-container").classList.remove("displayNone")
}
}
}
function setHostInfo(hosts) {
@@ -744,10 +789,10 @@ navigator.permissions.query({ name: "clipboard-write" }).then(function(result) {
document.addEventListener("system_info_update", (e) => setDeviceInfo(e.detail))
useBetaChannelField.addEventListener('change', (e) => {
if (e.target.checked) {
getParameterSettingsEntry("test_diffusers").classList.remove('displayNone')
} else {
getParameterSettingsEntry("test_diffusers").classList.add('displayNone')
useBetaChannelField.addEventListener("change", (e) => {
if (e.target.checked) {
getParameterSettingsEntry("test_diffusers").classList.remove("displayNone")
} else {
getParameterSettingsEntry("test_diffusers").classList.add("displayNone")
}
})

View File

@@ -552,17 +552,23 @@ class ModelDropdown {
this.createModelNodeList(`${folderName || ""}/${childFolderName}`, childModels, false)
)
} else {
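// a model entry is either a plain string, or an object mapping a model id to a display name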
let modelId = model
let modelName = model
if (typeof model === "object") {
modelId = Object.keys(model)[0]
modelName = model[modelId]
}
const classes = ["model-file"]
if (isRootFolder) {
classes.push("in-root-folder")
}
// Remove the leading slash from the model path
const fullPath = folderName ? `${folderName.substring(1)}/${model}` : model
const fullPath = folderName ? `${folderName.substring(1)}/${modelId}` : modelId
modelsMap.set(
model,
modelId,
createElement("li", { "data-path": fullPath }, classes, [
createElement("i", undefined, ["fa-regular", "fa-file", "icon"]),
model,
modelName,
])
)
}
@@ -643,22 +649,6 @@ async function getModels(scanForMalicious = true) {
makeImageBtn.disabled = true
}
/* This code should no longer be needed. Commenting out for now, will cleanup later.
const sd_model_setting_key = "stable_diffusion_model"
const vae_model_setting_key = "vae_model"
const hypernetwork_model_key = "hypernetwork_model"
const stableDiffusionOptions = modelsOptions['stable-diffusion']
const vaeOptions = modelsOptions['vae']
const hypernetworkOptions = modelsOptions['hypernetwork']
// TODO: set default for model here too
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]
if (getSetting(sd_model_setting_key) == '' || SETTINGS[sd_model_setting_key].value == '') {
setSetting(sd_model_setting_key, stableDiffusionOptions[0])
}
*/
// notify ModelDropdown objects to refresh
document.dispatchEvent(new Event("refreshModels"))
} catch (e) {
@@ -667,4 +657,7 @@
}
// reload models button
document.querySelector("#reload-models").addEventListener("click", () => getModels())
document.querySelector("#reload-models").addEventListener("click", (e) => {
e.stopPropagation()
getModels()
})

View File

@@ -109,8 +109,10 @@
imageObj.onload = function() {
// Calculate the maximum cropped dimensions
const maxCroppedWidth = Math.floor(this.width / 64) * 64;
const maxCroppedHeight = Math.floor(this.height / 64) * 64;
const step = customWidthField.step;
const maxCroppedWidth = Math.floor(this.width / step) * step;
const maxCroppedHeight = Math.floor(this.height / step) * step;
canvas.width = maxCroppedWidth;
canvas.height = maxCroppedHeight;
@@ -122,35 +124,17 @@
// Draw the image with centered coordinates
context.drawImage(imageObj, x, y, this.width, this.height);
initImagePreview.src = canvas.toDataURL('image/png');
let bestWidth = maxCroppedWidth - maxCroppedWidth % 8
let bestHeight = maxCroppedHeight - maxCroppedHeight % 8
// Get the options from widthField and heightField
const widthOptions = Array.from(widthField.options).map(option => parseInt(option.value));
const heightOptions = Array.from(heightField.options).map(option => parseInt(option.value));
// Find the closest aspect ratio and closest to original dimensions
let bestWidth = widthOptions[0];
let bestHeight = heightOptions[0];
let minDifference = Math.abs(maxCroppedWidth / maxCroppedHeight - bestWidth / bestHeight);
let minDistance = Math.abs(maxCroppedWidth - bestWidth) + Math.abs(maxCroppedHeight - bestHeight);
for (const width of widthOptions) {
for (const height of heightOptions) {
const difference = Math.abs(maxCroppedWidth / maxCroppedHeight - width / height);
const distance = Math.abs(maxCroppedWidth - width) + Math.abs(maxCroppedHeight - height);
if (difference < minDifference || (difference === minDifference && distance < minDistance)) {
minDifference = difference;
minDistance = distance;
bestWidth = width;
bestHeight = height;
}
}
}
addImageSizeOption(bestWidth)
addImageSizeOption(bestHeight)
// Set the width and height to the closest aspect ratio and closest to original dimensions
widthField.value = bestWidth;
heightField.value = bestHeight;
initImagePreview.src = canvas.toDataURL('image/png');
};
function handlePaste(e) {

View File

@@ -0,0 +1,114 @@
/*
LoRA Prompt Parser 1.0
by Patrice
Pasting a prompt that contains a LoRA tag automatically selects the matching entry in the Easy Diffusion LoRA dropdown and strips the tag from the prompt. The LoRA must already be available in that dropdown (this is not a LoRA downloader).
*/
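// Example (assuming a LoRA named "pixel-art" exists in the dropdown): pasting
// "a castle on a hill, <lora:pixel-art:0.8>" selects "pixel-art", sets its
// weight to 0.8, and leaves "a castle on a hill" in the prompt field.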
(function() {
"use strict"
promptField.addEventListener('input', function(e) {
const { LoRA, prompt } = extractLoraTags(e.target.value);
//console.log('e.target: ' + JSON.stringify(LoRA));
if (LoRA !== null && LoRA.length > 0) {
promptField.value = prompt.replace(/,+$/, ''); // remove any trailing ,
if (testDiffusers?.checked === false) {
showToast("LoRA's are only supported with diffusers. Just stripping the LoRA tag from the prompt.")
}
}
if (LoRA !== null && LoRA.length > 0 && testDiffusers?.checked) {
for (let i = 0; i < LoRA.length; i++) {
//if (loraModelField.value !== LoRA[0].lora_model) {
// Set the new LoRA value
//console.log("Loading info");
//console.log(LoRA[0].lora_model_0);
//console.log(JSON.stringify(LoRa));
let lora = `lora_model_${i}`;
let alpha = `lora_alpha_${i}`;
let loramodel = document.getElementById(lora);
let alphavalue = document.getElementById(alpha);
loramodel.setAttribute("data-path", LoRA[i].lora_model_0);
loramodel.value = LoRA[i].lora_model_0;
alphavalue.value = LoRA[i].lora_alpha_0;
if (i != LoRA.length - 1) {
createLoraEntry();
}
}
//loraAlphaSlider.value = loraAlphaField.value * 100;
//TBD.value = LoRA[0].blockweights; // block weights not supported by ED at this time
//}
showToast("Prompt successfully processed", LoRA[0].lora_model_0);
//console.log('LoRa: ' + LoRA[0].lora_model_0);
//showToast("Prompt successfully processed", lora_model_0.value);
}
//promptField.dispatchEvent(new Event('change'));
});
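// Case-insensitive lookup: returns the canonical model name from `array`, or "" if absent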
function isModelAvailable(array, searchString) {
const foundItem = array.find(function(item) {
item = item.toString().toLowerCase();
return item === searchString.toLowerCase()
});
return foundItem || "";
}
// extract LoRA tags from strings
function extractLoraTags(prompt) {
// Define the regular expression for the tags
const regex = /<(?:lora|lyco):([^:>]+)(?::([^:>]*))?(?::([^:>]*))?>/gi
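// e.g. matches "<lora:pixel-art>", "<lora:pixel-art:0.8>", "<lyco:pixel-art:0.8:blockweights>"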
// Initialize an array to hold the matches
let matches = []
// Iterate over the string, finding matches
for (const match of prompt.matchAll(regex)) {
const modelFileName = isModelAvailable(modelsCache.options.lora, match[1].trim())
if (modelFileName !== "") {
// Initialize an object to hold a match
let loraTag = {
lora_model_0: modelFileName,
}
//console.log("Model:" + modelFileName);
// If weight is provided, add it to the loraTag object
if (match[2] !== undefined && match[2] !== '') {
loraTag.lora_alpha_0 = parseFloat(match[2].trim())
} else {
loraTag.lora_alpha_0 = 0.5
}
// If blockweights are provided, add them to the loraTag object
if (match[3] !== undefined && match[3] !== '') {
loraTag.blockweights = match[3].trim()
}
// Add the loraTag object to the array of matches
matches.push(loraTag);
//console.log(JSON.stringify(matches));
} else {
showToast("LoRA not found: " + match[1].trim(), 5000, true)
}
}
// Clean up the prompt string, e.g. from "apple, banana, <lora:...>, orange, <lora:...> , pear <lora:...>, <lora:...>" to "apple, banana, orange, pear"
let cleanedPrompt = prompt.replace(regex, '').replace(/(\s*,\s*(?=\s*,|$))|(^\s*,\s*)|\s+/g, ' ').trim();
//console.log('Matches: ' + JSON.stringify(matches));
// Return the array of matches and cleaned prompt string
return {
LoRA: matches,
prompt: cleanedPrompt
}
}
})()