mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-13 17:57:20 +02:00
Compare commits
56 Commits
v2.5.46.be
...
v2.5.48
Author | SHA1 | Date | |
---|---|---|---|
404329f9b5 | |||
3929e88d87 | |||
83a5b5b46f | |||
b97c906128 | |||
b8328b6071 | |||
9a528496a3 | |||
6a95c602b1 | |||
f0f6578b9c | |||
83c93eb9ef | |||
befe8ad24e | |||
c5249e6144 | |||
9be3297c27 | |||
b6344ef6f9 | |||
76b7e32125 | |||
801a3dd598 | |||
d1fdf1766a | |||
35073adc1f | |||
d76930c7f4 | |||
7d496f4ad0 | |||
53b5ce6e2c | |||
38ab5b090f | |||
fa58996f37 | |||
56f92ccab0 | |||
4e444b418e | |||
3d9a9299dc | |||
ae34c9e84b | |||
eba7bab15e | |||
ee6db85768 | |||
05ed110519 | |||
9690fd1fa8 | |||
4cee1be99c | |||
d39e1da183 | |||
8538a684e7 | |||
47d7513dd8 | |||
432fd57581 | |||
9c06e2612a | |||
1d6742f463 | |||
2e849827d1 | |||
1e2c9ecb41 | |||
14679586a8 | |||
11fb83a2a7 | |||
4d3f55622a | |||
eedf6f0aad | |||
13592fae1a | |||
4dd05d3efe | |||
2e3059a7c8 | |||
3b53b5ebaf | |||
a9f1000af8 | |||
a9960ded01 | |||
ed84a23f36 | |||
8301cafb37 | |||
c906c5d14a | |||
6e52680fa8 | |||
7f32c531d7 | |||
17a11b94b2 | |||
e61549e0cd |
@ -22,6 +22,12 @@
|
||||
Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed.
|
||||
|
||||
### Detailed changelog
|
||||
* 2.5.48 - 1 Aug 2023 - (beta-only) Full support for ControlNets. You can select a control image to guide the AI. You can pick a filter to pre-process the image, and one of the known (or custom) controlnet models. Supports `OpenPose`, `Canny`, `Straight Lines`, `Depth`, `Line Art`, `Scribble`, `Soft Edge`, `Shuffle` and `Segment`.
|
||||
* 2.5.47 - 30 Jul 2023 - An option to use `Strict Mask Border` while inpainting, to avoid touching areas outside the mask. But this might show a slight outline of the mask, which you will have to touch up separately.
|
||||
* 2.5.47 - 29 Jul 2023 - (beta-only) Fix long prompts with SDXL.
|
||||
* 2.5.47 - 29 Jul 2023 - (beta-only) Fix red dots in some SDXL images.
|
||||
* 2.5.47 - 29 Jul 2023 - Significantly faster `Fix Faces` and `Upscale` buttons (on the image). They no longer need to generate the image from scratch, instead they just upscale/fix the generated image in-place.
|
||||
* 2.5.47 - 28 Jul 2023 - Lots of internal code reorganization, in preparation for supporting Controlnets. No user-facing changes.
|
||||
* 2.5.46 - 27 Jul 2023 - (beta-only) Full support for SD-XL models (base and refiner)!
|
||||
* 2.5.45 - 24 Jul 2023 - (beta-only) Hide the samplers that won't be supported in the new diffusers version.
|
||||
* 2.5.45 - 22 Jul 2023 - (beta-only) Fix the recently-broken inpainting models.
|
||||
|
@ -18,7 +18,7 @@ os_name = platform.system()
|
||||
modules_to_check = {
|
||||
"torch": ("1.11.0", "1.13.1", "2.0.0"),
|
||||
"torchvision": ("0.12.0", "0.14.1", "0.15.1"),
|
||||
"sdkit": "1.0.142",
|
||||
"sdkit": "1.0.165",
|
||||
"stable-diffusion-sdkit": "2.1.4",
|
||||
"rich": "12.6.0",
|
||||
"uvicorn": "0.19.0",
|
||||
|
@ -32,6 +32,8 @@ logging.basicConfig(
|
||||
|
||||
SD_DIR = os.getcwd()
|
||||
|
||||
ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, ".."))
|
||||
|
||||
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
||||
|
||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
||||
|
@ -5,10 +5,11 @@ import traceback
|
||||
from typing import Union
|
||||
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import TaskData
|
||||
from easydiffusion.types import ModelsData
|
||||
from easydiffusion.utils import log
|
||||
from sdkit import Context
|
||||
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
|
||||
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
|
||||
from sdkit.utils import hash_file_quick
|
||||
|
||||
KNOWN_MODEL_TYPES = [
|
||||
@ -19,6 +20,8 @@ KNOWN_MODEL_TYPES = [
|
||||
"realesrgan",
|
||||
"lora",
|
||||
"codeformer",
|
||||
"embeddings",
|
||||
"controlnet",
|
||||
]
|
||||
MODEL_EXTENSIONS = {
|
||||
"stable-diffusion": [".ckpt", ".safetensors"],
|
||||
@ -29,6 +32,7 @@ MODEL_EXTENSIONS = {
|
||||
"lora": [".ckpt", ".safetensors"],
|
||||
"codeformer": [".pth"],
|
||||
"embeddings": [".pt", ".bin", ".safetensors"],
|
||||
"controlnet": [".pth", ".safetensors"],
|
||||
}
|
||||
DEFAULT_MODELS = {
|
||||
"stable-diffusion": [
|
||||
@ -57,7 +61,9 @@ def init():
|
||||
|
||||
|
||||
def load_default_models(context: Context):
|
||||
set_vram_optimizations(context)
|
||||
from easydiffusion import runtime
|
||||
|
||||
runtime.set_vram_optimizations(context)
|
||||
|
||||
config = app.getConfig()
|
||||
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
|
||||
@ -138,43 +144,32 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
||||
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
|
||||
|
||||
|
||||
def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else ""
|
||||
upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else ""
|
||||
|
||||
model_paths_in_req = {
|
||||
"stable-diffusion": task_data.use_stable_diffusion_model,
|
||||
"vae": task_data.use_vae_model,
|
||||
"hypernetwork": task_data.use_hypernetwork_model,
|
||||
"codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None,
|
||||
"gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None,
|
||||
"realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None,
|
||||
"latent_upscaler": True if "latent_upscaler" in upscale_lower else None,
|
||||
"nsfw_checker": True if task_data.block_nsfw else None,
|
||||
"lora": task_data.use_lora_model,
|
||||
}
|
||||
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
|
||||
models_to_reload = {
|
||||
model_type: path
|
||||
for model_type, path in model_paths_in_req.items()
|
||||
if context.model_paths.get(model_type) != path
|
||||
for model_type, path in models_data.model_paths.items()
|
||||
if context.model_paths.get(model_type) != path or (path is not None and context.models.get(model_type) is None)
|
||||
}
|
||||
|
||||
if task_data.codeformer_upscale_faces:
|
||||
if models_data.model_paths.get("codeformer"):
|
||||
if "realesrgan" not in models_to_reload and "realesrgan" not in context.models:
|
||||
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
|
||||
models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
|
||||
elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None:
|
||||
del models_to_reload["realesrgan"] # don't unload realesrgan
|
||||
|
||||
if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD
|
||||
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
|
||||
for model_type in models_to_force_reload:
|
||||
if model_type not in models_data.model_paths:
|
||||
continue
|
||||
models_to_reload[model_type] = models_data.model_paths[model_type]
|
||||
|
||||
for model_type, model_path_in_req in models_to_reload.items():
|
||||
context.model_paths[model_type] = model_path_in_req
|
||||
|
||||
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
||||
extra_params = models_data.model_params.get(model_type, {})
|
||||
try:
|
||||
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
||||
action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already
|
||||
if model_type in context.model_load_errors:
|
||||
del context.model_load_errors[model_type]
|
||||
except Exception as e:
|
||||
@ -183,24 +178,22 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks
|
||||
|
||||
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = resolve_model_to_use(
|
||||
task_data.use_stable_diffusion_model, model_type="stable-diffusion"
|
||||
)
|
||||
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
|
||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
|
||||
task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora")
|
||||
|
||||
if task_data.use_face_correction:
|
||||
if "gfpgan" in task_data.use_face_correction.lower():
|
||||
model_type = "gfpgan"
|
||||
elif "codeformer" in task_data.use_face_correction.lower():
|
||||
model_type = "codeformer"
|
||||
def resolve_model_paths(models_data: ModelsData):
|
||||
model_paths = models_data.model_paths
|
||||
for model_type in model_paths:
|
||||
skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"]
|
||||
if model_type in skip_models: # doesn't use model paths
|
||||
continue
|
||||
if model_type == "codeformer":
|
||||
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
|
||||
elif model_type == "controlnet":
|
||||
model_id = model_paths[model_type]
|
||||
model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
|
||||
if model_info:
|
||||
filename = model_info.get("url", "").split("/")[-1]
|
||||
download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False)
|
||||
|
||||
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type)
|
||||
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
|
||||
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
|
||||
model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
|
||||
|
||||
|
||||
def fail_if_models_did_not_load(context: Context):
|
||||
@ -222,28 +215,17 @@ def download_default_models_if_necessary():
|
||||
print(model_type, "model(s) found.")
|
||||
|
||||
|
||||
def download_if_necessary(model_type: str, file_name: str, model_id: str):
|
||||
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
|
||||
model_path = os.path.join(app.MODELS_DIR, model_type, file_name)
|
||||
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
||||
|
||||
other_models_exist = any_model_exists(model_type)
|
||||
other_models_exist = any_model_exists(model_type) and skip_if_others_exist
|
||||
known_model_exists = os.path.exists(model_path)
|
||||
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
|
||||
|
||||
if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
|
||||
print("> download", model_type, model_id)
|
||||
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR)
|
||||
|
||||
|
||||
def set_vram_optimizations(context: Context):
|
||||
config = app.getConfig()
|
||||
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||
|
||||
if vram_usage_level != context.vram_usage_level:
|
||||
context.vram_usage_level = vram_usage_level
|
||||
return True
|
||||
|
||||
return False
|
||||
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False)
|
||||
|
||||
|
||||
def migrate_legacy_model_location():
|
||||
@ -266,16 +248,6 @@ def any_model_exists(model_type: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def set_clip_skip(context: Context, task_data: TaskData):
|
||||
clip_skip = task_data.clip_skip
|
||||
|
||||
if clip_skip != context.clip_skip:
|
||||
context.clip_skip = clip_skip
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
@ -324,12 +296,26 @@ def is_malicious_model(file_path):
|
||||
def getModels(scan_for_malicious: bool = True):
|
||||
models = {
|
||||
"options": {
|
||||
"stable-diffusion": ["sd-v1-4"],
|
||||
"stable-diffusion": [{"sd-v1-4": "SD 1.4"}],
|
||||
"vae": [],
|
||||
"hypernetwork": [],
|
||||
"lora": [],
|
||||
"codeformer": ["codeformer"],
|
||||
"codeformer": [{"codeformer": "CodeFormer"}],
|
||||
"embeddings": [],
|
||||
"controlnet": [
|
||||
{"control_v11p_sd15_canny": "Canny (*)"},
|
||||
{"control_v11p_sd15_openpose": "OpenPose (*)"},
|
||||
{"control_v11p_sd15_normalbae": "Normal BAE (*)"},
|
||||
{"control_v11f1p_sd15_depth": "Depth (*)"},
|
||||
{"control_v11p_sd15_scribble": "Scribble"},
|
||||
{"control_v11p_sd15_softedge": "Soft Edge"},
|
||||
{"control_v11p_sd15_inpaint": "Inpaint"},
|
||||
{"control_v11p_sd15_lineart": "Line Art"},
|
||||
{"control_v11p_sd15s2_lineart_anime": "Line Art Anime"},
|
||||
{"control_v11p_sd15_mlsd": "Straight Lines"},
|
||||
{"control_v11p_sd15_seg": "Segment"},
|
||||
{"control_v11e_sd15_shuffle": "Shuffle"},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
@ -338,9 +324,9 @@ def getModels(scan_for_malicious: bool = True):
|
||||
class MaliciousModelException(Exception):
|
||||
"Raised when picklescan reports a problem with a model"
|
||||
|
||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
|
||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]):
|
||||
tree = list(default_entries)
|
||||
nonlocal models_scanned
|
||||
tree = []
|
||||
for entry in sorted(
|
||||
os.scandir(directory),
|
||||
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
||||
@ -359,7 +345,14 @@ def getModels(scan_for_malicious: bool = True):
|
||||
raise MaliciousModelException(entry.path)
|
||||
if scan_for_malicious:
|
||||
known_models[entry.path] = mtime
|
||||
tree.append(entry.name[: -len(matching_suffix)])
|
||||
model_id = entry.name[: -len(matching_suffix)]
|
||||
model_exists = False
|
||||
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
|
||||
if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m):
|
||||
model_exists = True
|
||||
break
|
||||
if not model_exists:
|
||||
tree.append(model_id)
|
||||
elif entry.is_dir():
|
||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
||||
|
||||
@ -376,7 +369,8 @@ def getModels(scan_for_malicious: bool = True):
|
||||
os.makedirs(models_dir)
|
||||
|
||||
try:
|
||||
models["options"][model_type] = scan_directory(models_dir, model_extensions)
|
||||
default_tree = models["options"].get(model_type, [])
|
||||
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree)
|
||||
except MaliciousModelException as e:
|
||||
models["scan-error"] = str(e)
|
||||
|
||||
@ -389,6 +383,7 @@ def getModels(scan_for_malicious: bool = True):
|
||||
listModels(model_type="gfpgan")
|
||||
listModels(model_type="lora")
|
||||
listModels(model_type="embeddings")
|
||||
listModels(model_type="controlnet")
|
||||
|
||||
if scan_for_malicious and models_scanned > 0:
|
||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||
|
98
ui/easydiffusion/package_manager.py
Normal file
98
ui/easydiffusion/package_manager.py
Normal file
@ -0,0 +1,98 @@
|
||||
import sys
|
||||
import os
|
||||
import platform
|
||||
from importlib.metadata import version as pkg_version
|
||||
|
||||
from sdkit.utils import log
|
||||
|
||||
from easydiffusion import app
|
||||
|
||||
# future home of scripts/check_modules.py
|
||||
|
||||
manifest = {
|
||||
"tensorrt": {
|
||||
"install": [
|
||||
"nvidia-cudnn --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
|
||||
"tensorrt-libs --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
|
||||
"tensorrt --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
|
||||
],
|
||||
"uninstall": ["tensorrt"],
|
||||
# TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error)
|
||||
}
|
||||
}
|
||||
installing = []
|
||||
|
||||
# remove this once TRT releases on pypi
|
||||
if platform.system() == "Windows":
|
||||
trt_dir = os.path.join(app.ROOT_DIR, "tensorrt")
|
||||
if os.path.exists(trt_dir):
|
||||
files = os.listdir(trt_dir)
|
||||
|
||||
packages = manifest["tensorrt"]["install"]
|
||||
packages = tuple(p.replace("-", "_") for p in packages)
|
||||
|
||||
wheels = []
|
||||
for p in packages:
|
||||
p = p.split(" ")[0]
|
||||
f = next((f for f in files if f.startswith(p) and f.endswith((".whl", ".tar.gz"))), None)
|
||||
if f:
|
||||
wheels.append(os.path.join(trt_dir, f))
|
||||
|
||||
manifest["tensorrt"]["install"] = wheels
|
||||
|
||||
|
||||
def get_installed_packages() -> list:
|
||||
return {module_name: version(module_name) for module_name in manifest if is_installed(module_name)}
|
||||
|
||||
|
||||
def is_installed(module_name) -> bool:
|
||||
return version(module_name) is not None
|
||||
|
||||
|
||||
def install(module_name):
|
||||
if is_installed(module_name):
|
||||
log.info(f"{module_name} has already been installed!")
|
||||
return
|
||||
if module_name in installing:
|
||||
log.info(f"{module_name} is already installing!")
|
||||
return
|
||||
|
||||
if module_name not in manifest:
|
||||
raise RuntimeError(f"Can't install unknown package: {module_name}!")
|
||||
|
||||
commands = manifest[module_name]["install"]
|
||||
commands = [f"python -m pip install --upgrade {cmd}" for cmd in commands]
|
||||
|
||||
installing.append(module_name)
|
||||
|
||||
try:
|
||||
for cmd in commands:
|
||||
print(">", cmd)
|
||||
if os.system(cmd) != 0:
|
||||
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
|
||||
finally:
|
||||
installing.remove(module_name)
|
||||
|
||||
|
||||
def uninstall(module_name):
|
||||
if not is_installed(module_name):
|
||||
log.info(f"{module_name} hasn't been installed!")
|
||||
return
|
||||
|
||||
if module_name not in manifest:
|
||||
raise RuntimeError(f"Can't uninstall unknown package: {module_name}!")
|
||||
|
||||
commands = manifest[module_name]["uninstall"]
|
||||
commands = [f"python -m pip uninstall -y {cmd}" for cmd in commands]
|
||||
|
||||
for cmd in commands:
|
||||
print(">", cmd)
|
||||
if os.system(cmd) != 0:
|
||||
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
|
||||
|
||||
|
||||
def version(module_name: str) -> str:
|
||||
try:
|
||||
return pkg_version(module_name)
|
||||
except:
|
||||
return None
|
@ -1,279 +0,0 @@
|
||||
import json
|
||||
import pprint
|
||||
import queue
|
||||
import time
|
||||
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.types import GenerateImageRequest
|
||||
from easydiffusion.types import Image as ResponseImage
|
||||
from easydiffusion.types import Response, TaskData, UserInitiatedStop
|
||||
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
|
||||
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
|
||||
from sdkit import Context
|
||||
from sdkit.filter import apply_filters
|
||||
from sdkit.generate import generate_images
|
||||
from sdkit.models import load_model
|
||||
from sdkit.utils import (
|
||||
diffusers_latent_samples_to_images,
|
||||
gc,
|
||||
img_to_base64_str,
|
||||
img_to_buffer,
|
||||
latent_samples_to_images,
|
||||
get_device_usage,
|
||||
)
|
||||
|
||||
context = Context() # thread-local
|
||||
"""
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
"""
|
||||
|
||||
|
||||
def init(device):
|
||||
"""
|
||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||
"""
|
||||
context.stop_processing = False
|
||||
context.temp_images = {}
|
||||
context.partial_x_samples = None
|
||||
context.model_load_errors = {}
|
||||
context.enable_codeformer = True
|
||||
|
||||
from easydiffusion import app
|
||||
|
||||
app_config = app.getConfig()
|
||||
context.test_diffusers = (
|
||||
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
|
||||
)
|
||||
|
||||
log.info("Device usage during initialization:")
|
||||
get_device_usage(device, log_info=True, process_usage_only=False)
|
||||
|
||||
device_manager.device_init(context, device)
|
||||
|
||||
|
||||
def make_images(
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
context.stop_processing = False
|
||||
print_task_info(req, task_data)
|
||||
|
||||
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||
|
||||
res = Response(
|
||||
req,
|
||||
task_data,
|
||||
images=construct_response(images, seeds, task_data, base_seed=req.seed),
|
||||
)
|
||||
res = res.json()
|
||||
data_queue.put(json.dumps(res))
|
||||
log.info("Task completed")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
||||
req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[")
|
||||
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
||||
log.info(f"request: {req_str}")
|
||||
log.info(f"task data: {task_str}")
|
||||
|
||||
|
||||
def make_images_internal(
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
images, user_stopped = generate_images_internal(
|
||||
req,
|
||||
task_data,
|
||||
data_queue,
|
||||
task_temp_images,
|
||||
step_callback,
|
||||
task_data.stream_image_progress,
|
||||
task_data.stream_image_progress_interval,
|
||||
)
|
||||
gc(context)
|
||||
filtered_images = filter_images(req, task_data, images, user_stopped)
|
||||
|
||||
if task_data.save_to_disk_path is not None:
|
||||
save_images_to_disk(images, filtered_images, req, task_data)
|
||||
|
||||
seeds = [*range(req.seed, req.seed + len(images))]
|
||||
if task_data.show_only_filtered_image or filtered_images is images:
|
||||
return filtered_images, seeds
|
||||
else:
|
||||
return images + filtered_images, seeds + seeds
|
||||
|
||||
|
||||
def generate_images_internal(
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
stream_image_progress: bool,
|
||||
stream_image_progress_interval: int,
|
||||
):
|
||||
context.temp_images.clear()
|
||||
|
||||
callback = make_step_callback(
|
||||
req,
|
||||
task_data,
|
||||
data_queue,
|
||||
task_temp_images,
|
||||
step_callback,
|
||||
stream_image_progress,
|
||||
stream_image_progress_interval,
|
||||
)
|
||||
|
||||
try:
|
||||
if req.init_image is not None and not context.test_diffusers:
|
||||
req.sampler_name = "ddim"
|
||||
|
||||
images = generate_images(context, callback=callback, **req.dict())
|
||||
user_stopped = False
|
||||
except UserInitiatedStop:
|
||||
images = []
|
||||
user_stopped = True
|
||||
if context.partial_x_samples is not None:
|
||||
if context.test_diffusers:
|
||||
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
|
||||
else:
|
||||
images = latent_samples_to_images(context, context.partial_x_samples)
|
||||
finally:
|
||||
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
|
||||
if not context.test_diffusers:
|
||||
del context.partial_x_samples
|
||||
context.partial_x_samples = None
|
||||
|
||||
return images, user_stopped
|
||||
|
||||
|
||||
def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped):
|
||||
if user_stopped:
|
||||
return images
|
||||
|
||||
if task_data.block_nsfw:
|
||||
images = apply_filters(context, "nsfw_checker", images)
|
||||
|
||||
if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
|
||||
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
|
||||
prev_realesrgan_path = None
|
||||
if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
|
||||
prev_realesrgan_path = context.model_paths["realesrgan"]
|
||||
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
|
||||
load_model(context, "realesrgan")
|
||||
|
||||
try:
|
||||
images = apply_filters(
|
||||
context,
|
||||
"codeformer",
|
||||
images,
|
||||
upscale_faces=task_data.codeformer_upscale_faces,
|
||||
codeformer_fidelity=task_data.codeformer_fidelity,
|
||||
)
|
||||
finally:
|
||||
if prev_realesrgan_path:
|
||||
context.model_paths["realesrgan"] = prev_realesrgan_path
|
||||
load_model(context, "realesrgan")
|
||||
elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
|
||||
images = apply_filters(context, "gfpgan", images)
|
||||
|
||||
if task_data.use_upscale:
|
||||
if "realesrgan" in task_data.use_upscale.lower():
|
||||
images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount)
|
||||
elif task_data.use_upscale == "latent_upscaler":
|
||||
images = apply_filters(
|
||||
context,
|
||||
"latent_upscaler",
|
||||
images,
|
||||
scale=task_data.upscale_amount,
|
||||
latent_upscaler_options={
|
||||
"prompt": req.prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"seed": req.seed,
|
||||
"num_inference_steps": task_data.latent_upscaler_steps,
|
||||
"guidance_scale": 0,
|
||||
},
|
||||
)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=img_to_base64_str(
|
||||
img,
|
||||
task_data.output_format,
|
||||
task_data.output_quality,
|
||||
task_data.output_lossless,
|
||||
),
|
||||
seed=seed,
|
||||
)
|
||||
for img, seed in zip(images, seeds)
|
||||
]
|
||||
|
||||
|
||||
def make_step_callback(
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
stream_image_progress: bool,
|
||||
stream_image_progress_interval: int,
|
||||
):
|
||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||
last_callback_time = -1
|
||||
|
||||
def update_temp_img(x_samples, task_temp_images: list):
|
||||
partial_images = []
|
||||
|
||||
if context.test_diffusers:
|
||||
images = diffusers_latent_samples_to_images(context, x_samples)
|
||||
else:
|
||||
images = latent_samples_to_images(context, x_samples)
|
||||
|
||||
if task_data.block_nsfw:
|
||||
images = apply_filters(context, "nsfw_checker", images)
|
||||
|
||||
for i, img in enumerate(images):
|
||||
buf = img_to_buffer(img, output_format="JPEG")
|
||||
|
||||
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
|
||||
del images
|
||||
return partial_images
|
||||
|
||||
def on_image_step(x_samples, i, *args):
|
||||
nonlocal last_callback_time
|
||||
|
||||
if context.test_diffusers:
|
||||
context.partial_x_samples = (x_samples, args[0])
|
||||
else:
|
||||
context.partial_x_samples = x_samples
|
||||
|
||||
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||
last_callback_time = time.time()
|
||||
|
||||
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
||||
|
||||
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
|
||||
progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images)
|
||||
|
||||
data_queue.put(json.dumps(progress))
|
||||
|
||||
step_callback()
|
||||
|
||||
if context.stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
return on_image_step
|
53
ui/easydiffusion/runtime.py
Normal file
53
ui/easydiffusion/runtime.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
A runtime that runs on a specific device (in a thread).
|
||||
|
||||
It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context.
|
||||
|
||||
This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function.
|
||||
"""
|
||||
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.utils import log
|
||||
from sdkit import Context
|
||||
from sdkit.utils import get_device_usage
|
||||
|
||||
context = Context() # thread-local
|
||||
"""
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
"""
|
||||
|
||||
|
||||
def init(device):
|
||||
"""
|
||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||
"""
|
||||
context.stop_processing = False
|
||||
context.temp_images = {}
|
||||
context.partial_x_samples = None
|
||||
context.model_load_errors = {}
|
||||
context.enable_codeformer = True
|
||||
|
||||
from easydiffusion import app
|
||||
|
||||
app_config = app.getConfig()
|
||||
context.test_diffusers = (
|
||||
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
|
||||
)
|
||||
|
||||
log.info("Device usage during initialization:")
|
||||
get_device_usage(device, log_info=True, process_usage_only=False)
|
||||
|
||||
device_manager.device_init(context, device)
|
||||
|
||||
|
||||
def set_vram_optimizations(context: Context):
|
||||
from easydiffusion import app
|
||||
|
||||
config = app.getConfig()
|
||||
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||
|
||||
if vram_usage_level != context.vram_usage_level:
|
||||
context.vram_usage_level = vram_usage_level
|
||||
return True
|
||||
|
||||
return False
|
@ -8,8 +8,17 @@ import os
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
|
||||
from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
|
||||
from easydiffusion import app, model_manager, task_manager, package_manager
|
||||
from easydiffusion.tasks import RenderTask, FilterTask
|
||||
from easydiffusion.types import (
|
||||
GenerateImageRequest,
|
||||
FilterImageRequest,
|
||||
MergeRequest,
|
||||
TaskData,
|
||||
ModelsData,
|
||||
OutputFormatData,
|
||||
convert_legacy_render_req_to_new,
|
||||
)
|
||||
from easydiffusion.utils import log
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@ -97,6 +106,10 @@ def init():
|
||||
def render(req: dict):
|
||||
return render_internal(req)
|
||||
|
||||
@server_api.post("/filter")
|
||||
def render(req: dict):
|
||||
return filter_internal(req)
|
||||
|
||||
@server_api.post("/model/merge")
|
||||
def model_merge(req: dict):
|
||||
print(req)
|
||||
@ -122,6 +135,10 @@ def init():
|
||||
def stop_cloudflare_tunnel(req: dict):
|
||||
return stop_cloudflare_tunnel_internal(req)
|
||||
|
||||
@server_api.post("/package/{package_name:str}")
|
||||
def modify_package(package_name: str, req: dict):
|
||||
return modify_package_internal(package_name, req)
|
||||
|
||||
@server_api.get("/")
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
|
||||
@ -213,24 +230,36 @@ def ping_internal(session_id: str = None):
|
||||
if task_manager.current_state_error:
|
||||
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
raise HTTPException(status_code=500, detail="Render thread is dead.")
|
||||
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
|
||||
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
|
||||
# Alive
|
||||
response = {"status": str(task_manager.current_state)}
|
||||
|
||||
if session_id:
|
||||
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
||||
response["tasks"] = {id(t): t.status for t in session.tasks}
|
||||
|
||||
response["devices"] = task_manager.get_devices()
|
||||
response["packages_installed"] = package_manager.get_installed_packages()
|
||||
response["packages_installing"] = package_manager.installing
|
||||
|
||||
if cloudflare.address != None:
|
||||
response["cloudflare"] = cloudflare.address
|
||||
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
|
||||
def render_internal(req: dict):
|
||||
try:
|
||||
req = convert_legacy_render_req_to_new(req)
|
||||
|
||||
# separate out the request data into rendering and task-specific data
|
||||
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
|
||||
task_data: TaskData = TaskData.parse_obj(req)
|
||||
models_data: ModelsData = ModelsData.parse_obj(req)
|
||||
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
|
||||
|
||||
# Overwrite user specified save path
|
||||
config = app.getConfig()
|
||||
@ -240,28 +269,53 @@ def render_internal(req: dict):
|
||||
render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
|
||||
|
||||
app.save_to_config(
|
||||
task_data.use_stable_diffusion_model,
|
||||
task_data.use_vae_model,
|
||||
task_data.use_hypernetwork_model,
|
||||
models_data.model_paths.get("stable-diffusion"),
|
||||
models_data.model_paths.get("vae"),
|
||||
models_data.model_paths.get("hypernetwork"),
|
||||
task_data.vram_usage_level,
|
||||
)
|
||||
|
||||
# enqueue the task
|
||||
new_task = task_manager.render(render_req, task_data)
|
||||
task = RenderTask(render_req, task_data, models_data, output_format)
|
||||
return enqueue_task(task)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def filter_internal(req: dict):
|
||||
try:
|
||||
session_id = req.get("session_id", "session")
|
||||
filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req)
|
||||
models_data: ModelsData = ModelsData.parse_obj(req)
|
||||
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
|
||||
|
||||
# enqueue the task
|
||||
task = FilterTask(filter_req, session_id, models_data, output_format)
|
||||
return enqueue_task(task)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def enqueue_task(task):
|
||||
try:
|
||||
task_manager.enqueue_task(task)
|
||||
response = {
|
||||
"status": str(task_manager.current_state),
|
||||
"queue": len(task_manager.tasks_queue),
|
||||
"stream": f"/image/stream/{id(new_task)}",
|
||||
"task": id(new_task),
|
||||
"stream": f"/image/stream/{task.id}",
|
||||
"task": task.id,
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def model_merge_internal(req: dict):
|
||||
@ -381,3 +435,19 @@ def stop_cloudflare_tunnel_internal(req: dict):
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def modify_package_internal(package_name: str, req: dict):
|
||||
try:
|
||||
cmd = req["command"]
|
||||
if cmd not in ("install", "uninstall"):
|
||||
raise RuntimeError(f"Unknown command: {cmd}")
|
||||
|
||||
cmd = getattr(package_manager, cmd)
|
||||
cmd(package_name)
|
||||
|
||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
@ -17,7 +17,7 @@ from typing import Any, Hashable
|
||||
|
||||
import torch
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||
from easydiffusion.tasks import Task
|
||||
from easydiffusion.utils import log
|
||||
from sdkit.utils import gc
|
||||
|
||||
@ -27,6 +27,7 @@ LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
||||
|
||||
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
|
||||
MAX_OVERLOAD_ALLOWED_RATIO = 2 # i.e. 2x pending tasks compared to the number of render threads
|
||||
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
@ -58,46 +59,6 @@ class ServerStates:
|
||||
pass
|
||||
|
||||
|
||||
class RenderTask: # Task with output queue and completion lock.
|
||||
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
|
||||
task_data.request_id = id(self)
|
||||
self.render_request: GenerateImageRequest = req # Initial Request
|
||||
self.task_data: TaskData = task_data
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||
self.error: Exception = None
|
||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||
|
||||
async def read_buffer_generator(self):
|
||||
try:
|
||||
while not self.buffer_queue.empty():
|
||||
res = self.buffer_queue.get(block=False)
|
||||
self.buffer_queue.task_done()
|
||||
yield res
|
||||
except queue.Empty as e:
|
||||
yield
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.lock.locked():
|
||||
return "running"
|
||||
if isinstance(self.error, StopAsyncIteration):
|
||||
return "stopped"
|
||||
if self.error:
|
||||
return "error"
|
||||
if not self.buffer_queue.empty():
|
||||
return "buffer"
|
||||
if self.response:
|
||||
return "completed"
|
||||
return "pending"
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class DataCache:
|
||||
def __init__(self):
|
||||
@ -123,8 +84,8 @@ class DataCache:
|
||||
# Remove Items
|
||||
for key in to_delete:
|
||||
(_, val) = self._base[key]
|
||||
if isinstance(val, RenderTask):
|
||||
log.debug(f"RenderTask {key} expired. Data removed.")
|
||||
if isinstance(val, Task):
|
||||
log.debug(f"Task {key} expired. Data removed.")
|
||||
elif isinstance(val, SessionState):
|
||||
log.debug(f"Session {key} expired. Data removed.")
|
||||
else:
|
||||
@ -220,8 +181,8 @@ class SessionState:
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
|
||||
def put(self, task, ttl=TASK_TTL):
|
||||
task_id = id(task)
|
||||
def put(self, task: Task, ttl=TASK_TTL):
|
||||
task_id = task.id
|
||||
self._tasks_ids.append(task_id)
|
||||
if not task_cache.put(task_id, task, ttl):
|
||||
return False
|
||||
@ -230,11 +191,16 @@ class SessionState:
|
||||
return True
|
||||
|
||||
|
||||
def keep_task_alive(task: Task):
|
||||
task_cache.keep(task.id, TASK_TTL)
|
||||
session_cache.keep(task.session_id, TASK_TTL)
|
||||
|
||||
|
||||
def thread_get_next_task():
|
||||
from easydiffusion import renderer
|
||||
from easydiffusion import runtime
|
||||
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
|
||||
log.warn(f"Render thread on device: {runtime.context.device} failed to acquire manager lock.")
|
||||
return None
|
||||
if len(tasks_queue) <= 0:
|
||||
manager_lock.release()
|
||||
@ -242,7 +208,7 @@ def thread_get_next_task():
|
||||
task = None
|
||||
try: # Select a render task.
|
||||
for queued_task in tasks_queue:
|
||||
if queued_task.render_device and renderer.context.device != queued_task.render_device:
|
||||
if queued_task.render_device and runtime.context.device != queued_task.render_device:
|
||||
# Is asking for a specific render device.
|
||||
if is_alive(queued_task.render_device) > 0:
|
||||
continue # requested device alive, skip current one.
|
||||
@ -251,7 +217,7 @@ def thread_get_next_task():
|
||||
queued_task.error = Exception(queued_task.render_device + " is not currently active.")
|
||||
task = queued_task
|
||||
break
|
||||
if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
|
||||
if not queued_task.render_device and runtime.context.device == "cpu" and is_alive() > 1:
|
||||
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
|
||||
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
|
||||
task = queued_task
|
||||
@ -266,19 +232,19 @@ def thread_get_next_task():
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error
|
||||
|
||||
from easydiffusion import model_manager, renderer
|
||||
from easydiffusion import model_manager, runtime
|
||||
|
||||
try:
|
||||
renderer.init(device)
|
||||
runtime.init(device)
|
||||
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
"device": renderer.context.device,
|
||||
"device_name": renderer.context.device_name,
|
||||
"device": runtime.context.device,
|
||||
"device_name": runtime.context.device_name,
|
||||
"alive": True,
|
||||
}
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
model_manager.load_default_models(renderer.context)
|
||||
model_manager.load_default_models(runtime.context)
|
||||
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
@ -290,8 +256,8 @@ def thread_render(device):
|
||||
session_cache.clean()
|
||||
task_cache.clean()
|
||||
if not weak_thread_data[threading.current_thread()]["alive"]:
|
||||
log.info(f"Shutting down thread for device {renderer.context.device}")
|
||||
model_manager.unload_all(renderer.context)
|
||||
log.info(f"Shutting down thread for device {runtime.context.device}")
|
||||
model_manager.unload_all(runtime.context)
|
||||
return
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
@ -311,62 +277,31 @@ def thread_render(device):
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
|
||||
log.info(f"Session {task.session_id} starting task {task.id} on {runtime.context.device_name}")
|
||||
if not task.lock.acquire(blocking=False):
|
||||
raise Exception("Got locked task from queue.")
|
||||
try:
|
||||
task.run()
|
||||
|
||||
def step_callback():
|
||||
global current_state_error
|
||||
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
|
||||
if (
|
||||
isinstance(current_state_error, SystemExit)
|
||||
or isinstance(current_state_error, StopAsyncIteration)
|
||||
or isinstance(task.error, StopAsyncIteration)
|
||||
):
|
||||
renderer.context.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
model_manager.resolve_model_paths(task.task_data)
|
||||
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
||||
model_manager.fail_if_models_did_not_load(renderer.context)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = renderer.make_images(
|
||||
task.render_request,
|
||||
task.task_data,
|
||||
task.buffer_queue,
|
||||
task.temp_images,
|
||||
step_callback,
|
||||
)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
keep_task_alive(task)
|
||||
except Exception as e:
|
||||
task.error = str(e)
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
log.error(traceback.format_exc())
|
||||
finally:
|
||||
gc(renderer.context)
|
||||
gc(runtime.context)
|
||||
task.lock.release()
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
|
||||
keep_task_alive(task)
|
||||
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
|
||||
log.info(f"Session {task.session_id} task {task.id} cancelled!")
|
||||
elif task.error is not None:
|
||||
log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
|
||||
log.info(f"Session {task.session_id} task {task.id} failed!")
|
||||
else:
|
||||
log.info(
|
||||
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
|
||||
)
|
||||
log.info(f"Session {task.session_id} task {task.id} completed by {runtime.context.device_name}.")
|
||||
current_state = ServerStates.Online
|
||||
|
||||
|
||||
@ -438,6 +373,12 @@ def get_devices():
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
# temp until TRT releases
|
||||
import os
|
||||
from easydiffusion import app
|
||||
|
||||
devices["enable_trt"] = os.path.exists(os.path.join(app.ROOT_DIR, "tensorrt"))
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
@ -548,28 +489,27 @@ def shutdown_event(): # Signal render thread to close on shutdown
|
||||
current_state_error = SystemExit("Application shutting down.")
|
||||
|
||||
|
||||
def render(render_req: GenerateImageRequest, task_data: TaskData):
|
||||
def enqueue_task(task: Task):
|
||||
current_thread_count = is_alive()
|
||||
if current_thread_count <= 0: # Render thread is dead
|
||||
raise ChildProcessError("Rendering thread has died.")
|
||||
|
||||
# Alive, check if task in cache
|
||||
session = get_cached_session(task_data.session_id, update_ttl=True)
|
||||
session = get_cached_session(task.session_id, update_ttl=True)
|
||||
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||
if current_thread_count < len(pending_tasks):
|
||||
if len(pending_tasks) > current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO:
|
||||
raise ConnectionRefusedError(
|
||||
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
|
||||
f"Session {task.session_id} already has {len(pending_tasks)} pending tasks, with {current_thread_count} workers."
|
||||
)
|
||||
|
||||
new_task = RenderTask(render_req, task_data)
|
||||
if session.put(new_task, TASK_TTL):
|
||||
if session.put(task, TASK_TTL):
|
||||
# Use twice the normal timeout for adding user requests.
|
||||
# Tries to force session.put to fail before tasks_queue.put would.
|
||||
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||
try:
|
||||
tasks_queue.append(new_task)
|
||||
tasks_queue.append(task)
|
||||
idle_event.set()
|
||||
return new_task
|
||||
return task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
raise RuntimeError("Failed to add task to cache.")
|
||||
|
3
ui/easydiffusion/tasks/__init__.py
Normal file
3
ui/easydiffusion/tasks/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .task import Task
|
||||
from .render_images import RenderTask
|
||||
from .filter_images import FilterTask
|
110
ui/easydiffusion/tasks/filter_images.py
Normal file
110
ui/easydiffusion/tasks/filter_images.py
Normal file
@ -0,0 +1,110 @@
|
||||
import json
|
||||
import pprint
|
||||
|
||||
from sdkit.filter import apply_filters
|
||||
from sdkit.models import load_model
|
||||
from sdkit.utils import img_to_base64_str, log
|
||||
|
||||
from easydiffusion import model_manager, runtime
|
||||
from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData
|
||||
|
||||
from .task import Task
|
||||
|
||||
|
||||
class FilterTask(Task):
|
||||
"For applying filters to input images"
|
||||
|
||||
def __init__(
|
||||
self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData
|
||||
):
|
||||
super().__init__(session_id)
|
||||
|
||||
self.request = req
|
||||
self.models_data = models_data
|
||||
self.output_format = output_format
|
||||
|
||||
# convert to multi-filter format, if necessary
|
||||
if isinstance(req.filter, str):
|
||||
req.filter_params = {req.filter: req.filter_params}
|
||||
req.filter = [req.filter]
|
||||
|
||||
if not isinstance(req.image, list):
|
||||
req.image = [req.image]
|
||||
|
||||
def run(self):
|
||||
"Runs the image filtering task on the assigned thread"
|
||||
|
||||
context = runtime.context
|
||||
|
||||
model_manager.resolve_model_paths(self.models_data)
|
||||
model_manager.reload_models_if_necessary(context, self.models_data)
|
||||
model_manager.fail_if_models_did_not_load(context)
|
||||
|
||||
print_task_info(self.request, self.models_data, self.output_format)
|
||||
|
||||
images = filter_images(context, self.request.image, self.request.filter, self.request.filter_params)
|
||||
|
||||
output_format = self.output_format
|
||||
images = [
|
||||
img_to_base64_str(
|
||||
img, output_format.output_format, output_format.output_quality, output_format.output_lossless
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
|
||||
res = FilterImageResponse(self.request, self.models_data, images=images)
|
||||
res = res.json()
|
||||
self.buffer_queue.put(json.dumps(res))
|
||||
log.info("Filter task completed")
|
||||
|
||||
self.response = res
|
||||
|
||||
|
||||
def filter_images(context, images, filters, filter_params={}):
|
||||
filters = filters if isinstance(filters, list) else [filters]
|
||||
|
||||
for filter_name in filters:
|
||||
params = filter_params.get(filter_name, {})
|
||||
|
||||
previous_state = before_filter(context, filter_name, params)
|
||||
|
||||
try:
|
||||
images = apply_filters(context, filter_name, images, **params)
|
||||
finally:
|
||||
after_filter(context, filter_name, params, previous_state)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def before_filter(context, filter_name, filter_params):
|
||||
if filter_name == "codeformer":
|
||||
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
|
||||
|
||||
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
|
||||
prev_realesrgan_path = None
|
||||
|
||||
upscale_faces = filter_params.get("upscale_faces", False)
|
||||
if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
|
||||
prev_realesrgan_path = context.model_paths.get("realesrgan")
|
||||
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
|
||||
load_model(context, "realesrgan")
|
||||
|
||||
return prev_realesrgan_path
|
||||
|
||||
|
||||
def after_filter(context, filter_name, filter_params, previous_state):
|
||||
if filter_name == "codeformer":
|
||||
prev_realesrgan_path = previous_state
|
||||
if prev_realesrgan_path:
|
||||
context.model_paths["realesrgan"] = prev_realesrgan_path
|
||||
load_model(context, "realesrgan")
|
||||
|
||||
|
||||
def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData):
|
||||
req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[")
|
||||
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
|
||||
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
|
||||
|
||||
log.info(f"request: {req_str}")
|
||||
log.info(f"models data: {models_data}")
|
||||
log.info(f"output format: {output_format}")
|
340
ui/easydiffusion/tasks/render_images.py
Normal file
340
ui/easydiffusion/tasks/render_images.py
Normal file
@ -0,0 +1,340 @@
|
||||
import json
|
||||
import pprint
|
||||
import queue
|
||||
import time
|
||||
|
||||
from easydiffusion import model_manager, runtime
|
||||
from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData
|
||||
from easydiffusion.types import Image as ResponseImage
|
||||
from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop
|
||||
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
|
||||
from sdkit.generate import generate_images
|
||||
from sdkit.utils import (
|
||||
diffusers_latent_samples_to_images,
|
||||
gc,
|
||||
img_to_base64_str,
|
||||
img_to_buffer,
|
||||
latent_samples_to_images,
|
||||
log,
|
||||
)
|
||||
|
||||
from .task import Task
|
||||
from .filter_images import filter_images
|
||||
|
||||
|
||||
class RenderTask(Task):
|
||||
"For image generation"
|
||||
|
||||
def __init__(
|
||||
self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
|
||||
):
|
||||
super().__init__(task_data.session_id)
|
||||
|
||||
task_data.request_id = self.id
|
||||
self.render_request: GenerateImageRequest = req # Initial Request
|
||||
self.task_data: TaskData = task_data
|
||||
self.models_data = models_data
|
||||
self.output_format = output_format
|
||||
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||
|
||||
def run(self):
|
||||
"Runs the image generation task on the assigned thread"
|
||||
|
||||
from easydiffusion import task_manager
|
||||
|
||||
context = runtime.context
|
||||
|
||||
def step_callback():
|
||||
task_manager.keep_task_alive(self)
|
||||
task_manager.current_state = task_manager.ServerStates.Rendering
|
||||
|
||||
if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance(
|
||||
self.error, StopAsyncIteration
|
||||
):
|
||||
context.stop_processing = True
|
||||
if isinstance(task_manager.current_state_error, StopAsyncIteration):
|
||||
self.error = task_manager.current_state_error
|
||||
task_manager.current_state_error = None
|
||||
log.info(f"Session {self.session_id} sent cancel signal for task {self.id}")
|
||||
|
||||
task_manager.current_state = task_manager.ServerStates.LoadingModel
|
||||
model_manager.resolve_model_paths(self.models_data)
|
||||
|
||||
models_to_force_reload = []
|
||||
if (
|
||||
runtime.set_vram_optimizations(context)
|
||||
or self.has_param_changed(context, "clip_skip")
|
||||
or self.trt_needs_reload(context)
|
||||
):
|
||||
models_to_force_reload.append("stable-diffusion")
|
||||
|
||||
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
|
||||
model_manager.fail_if_models_did_not_load(context)
|
||||
|
||||
task_manager.current_state = task_manager.ServerStates.Rendering
|
||||
self.response = make_images(
|
||||
context,
|
||||
self.render_request,
|
||||
self.task_data,
|
||||
self.models_data,
|
||||
self.output_format,
|
||||
self.buffer_queue,
|
||||
self.temp_images,
|
||||
step_callback,
|
||||
)
|
||||
|
||||
def has_param_changed(self, context, param_name):
|
||||
if not context.test_diffusers:
|
||||
return False
|
||||
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
|
||||
return True
|
||||
|
||||
model = context.models["stable-diffusion"]
|
||||
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
|
||||
return model["params"].get(param_name) != new_val
|
||||
|
||||
def trt_needs_reload(self, context):
|
||||
if not context.test_diffusers:
|
||||
return False
|
||||
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
|
||||
return True
|
||||
|
||||
model = context.models["stable-diffusion"]
|
||||
|
||||
# curr_convert_to_trt = model["params"].get("convert_to_tensorrt")
|
||||
new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False)
|
||||
|
||||
pipe = model["default"]
|
||||
is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr(
|
||||
pipe.unet, "_allocate_trt_buffers_backup"
|
||||
)
|
||||
if new_convert_to_trt and not is_trt_loaded:
|
||||
return True
|
||||
|
||||
curr_build_config = model["params"].get("trt_build_config")
|
||||
new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {})
|
||||
|
||||
return new_convert_to_trt and curr_build_config != new_build_config
|
||||
|
||||
|
||||
def make_images(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
output_format: OutputFormatData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
context.stop_processing = False
|
||||
print_task_info(req, task_data, models_data, output_format)
|
||||
|
||||
images, seeds = make_images_internal(
|
||||
context, req, task_data, models_data, output_format, data_queue, task_temp_images, step_callback
|
||||
)
|
||||
|
||||
res = GenerateImageResponse(
|
||||
req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format)
|
||||
)
|
||||
res = res.json()
|
||||
data_queue.put(json.dumps(res))
|
||||
log.info("Task completed")
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def print_task_info(
|
||||
req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
|
||||
):
|
||||
req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[")
|
||||
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
||||
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
|
||||
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
|
||||
|
||||
log.info(f"request: {req_str}")
|
||||
log.info(f"task data: {task_str}")
|
||||
# log.info(f"models data: {models_data}")
|
||||
log.info(f"output format: {output_format}")
|
||||
|
||||
|
||||
def make_images_internal(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
output_format: OutputFormatData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
images, user_stopped = generate_images_internal(
|
||||
context,
|
||||
req,
|
||||
task_data,
|
||||
models_data,
|
||||
data_queue,
|
||||
task_temp_images,
|
||||
step_callback,
|
||||
task_data.stream_image_progress,
|
||||
task_data.stream_image_progress_interval,
|
||||
)
|
||||
|
||||
gc(context)
|
||||
|
||||
filters, filter_params = task_data.filters, task_data.filter_params
|
||||
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
|
||||
|
||||
if task_data.save_to_disk_path is not None:
|
||||
save_images_to_disk(images, filtered_images, req, task_data, output_format)
|
||||
|
||||
seeds = [*range(req.seed, req.seed + len(images))]
|
||||
if task_data.show_only_filtered_image or filtered_images is images:
|
||||
return filtered_images, seeds
|
||||
else:
|
||||
return images + filtered_images, seeds + seeds
|
||||
|
||||
|
||||
def generate_images_internal(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
stream_image_progress: bool,
|
||||
stream_image_progress_interval: int,
|
||||
):
|
||||
context.temp_images.clear()
|
||||
|
||||
callback = make_step_callback(
|
||||
context,
|
||||
req,
|
||||
task_data,
|
||||
data_queue,
|
||||
task_temp_images,
|
||||
step_callback,
|
||||
stream_image_progress,
|
||||
stream_image_progress_interval,
|
||||
)
|
||||
|
||||
try:
|
||||
if req.init_image is not None and not context.test_diffusers:
|
||||
req.sampler_name = "ddim"
|
||||
|
||||
req.width, req.height = map(lambda x: x - x % 8, (req.width, req.height)) # clamp to 8
|
||||
|
||||
if req.control_image and task_data.control_filter_to_apply:
|
||||
req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0]
|
||||
|
||||
if context.test_diffusers:
|
||||
pipe = context.models["stable-diffusion"]["default"]
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers_backup"):
|
||||
setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup)
|
||||
delattr(pipe.unet, "_allocate_trt_buffers_backup")
|
||||
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers"):
|
||||
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
|
||||
if convert_to_trt:
|
||||
pipe.unet.forward = pipe.unet._trt_forward
|
||||
# pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward
|
||||
log.info(f"Setting unet.forward to TensorRT")
|
||||
else:
|
||||
log.info(f"Not using TensorRT for unet.forward")
|
||||
pipe.unet.forward = pipe.unet._non_trt_forward
|
||||
# pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward
|
||||
setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers)
|
||||
delattr(pipe.unet, "_allocate_trt_buffers")
|
||||
|
||||
images = generate_images(context, callback=callback, **req.dict())
|
||||
user_stopped = False
|
||||
except UserInitiatedStop:
|
||||
images = []
|
||||
user_stopped = True
|
||||
if context.partial_x_samples is not None:
|
||||
if context.test_diffusers:
|
||||
images = diffusers_latent_samples_to_images(context, context.partial_x_samples)
|
||||
else:
|
||||
images = latent_samples_to_images(context, context.partial_x_samples)
|
||||
finally:
|
||||
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
|
||||
if not context.test_diffusers:
|
||||
del context.partial_x_samples
|
||||
context.partial_x_samples = None
|
||||
|
||||
return images, user_stopped
|
||||
|
||||
|
||||
def construct_response(images: list, seeds: list, output_format: OutputFormatData):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=img_to_base64_str(
|
||||
img,
|
||||
output_format.output_format,
|
||||
output_format.output_quality,
|
||||
output_format.output_lossless,
|
||||
),
|
||||
seed=seed,
|
||||
)
|
||||
for img, seed in zip(images, seeds)
|
||||
]
|
||||
|
||||
|
||||
def make_step_callback(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
stream_image_progress: bool,
|
||||
stream_image_progress_interval: int,
|
||||
):
|
||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||
last_callback_time = -1
|
||||
|
||||
def update_temp_img(x_samples, task_temp_images: list):
|
||||
partial_images = []
|
||||
|
||||
if context.test_diffusers:
|
||||
images = diffusers_latent_samples_to_images(context, x_samples)
|
||||
else:
|
||||
images = latent_samples_to_images(context, x_samples)
|
||||
|
||||
if task_data.block_nsfw:
|
||||
images = filter_images(context, images, "nsfw_checker")
|
||||
|
||||
for i, img in enumerate(images):
|
||||
buf = img_to_buffer(img, output_format="JPEG")
|
||||
|
||||
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
|
||||
del images
|
||||
return partial_images
|
||||
|
||||
def on_image_step(x_samples, i, *args):
|
||||
nonlocal last_callback_time
|
||||
|
||||
if context.test_diffusers:
|
||||
context.partial_x_samples = (x_samples, args[0])
|
||||
else:
|
||||
context.partial_x_samples = x_samples
|
||||
|
||||
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||
last_callback_time = time.time()
|
||||
|
||||
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
||||
|
||||
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
|
||||
progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images)
|
||||
|
||||
data_queue.put(json.dumps(progress))
|
||||
|
||||
step_callback()
|
||||
|
||||
if context.stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
return on_image_step
|
47
ui/easydiffusion/tasks/task.py
Normal file
47
ui/easydiffusion/tasks/task.py
Normal file
@ -0,0 +1,47 @@
|
||||
from threading import Lock
|
||||
from queue import Queue, Empty as EmptyQueueException
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Task:
|
||||
"Task with output queue and completion lock"
|
||||
|
||||
def __init__(self, session_id):
|
||||
self.id = id(self)
|
||||
self.session_id = session_id
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.error: Exception = None
|
||||
self.lock: Lock = Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: Queue = Queue() # Queue of JSON string segments
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
|
||||
async def read_buffer_generator(self):
|
||||
try:
|
||||
while not self.buffer_queue.empty():
|
||||
res = self.buffer_queue.get(block=False)
|
||||
self.buffer_queue.task_done()
|
||||
yield res
|
||||
except EmptyQueueException as e:
|
||||
yield
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.lock.locked():
|
||||
return "running"
|
||||
if isinstance(self.error, StopAsyncIteration):
|
||||
return "stopped"
|
||||
if self.error:
|
||||
return "error"
|
||||
if not self.buffer_queue.empty():
|
||||
return "buffer"
|
||||
if self.response:
|
||||
return "completed"
|
||||
return "pending"
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
def run(self):
|
||||
"Override this to implement the task's behavior"
|
||||
pass
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -17,8 +17,11 @@ class GenerateImageRequest(BaseModel):
|
||||
|
||||
init_image: Any = None
|
||||
init_image_mask: Any = None
|
||||
control_image: Any = None
|
||||
control_alpha: Union[float, List[float]] = None
|
||||
prompt_strength: float = 0.8
|
||||
preserve_init_image_color_profile = False
|
||||
strict_mask_border = False
|
||||
|
||||
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||
hypernetwork_strength: float = 0
|
||||
@ -26,6 +29,35 @@ class GenerateImageRequest(BaseModel):
|
||||
tiling: str = "none" # "none", "x", "y", "xy"
|
||||
|
||||
|
||||
class FilterImageRequest(BaseModel):
|
||||
image: Any = None
|
||||
filter: Union[str, List[str]] = None
|
||||
filter_params: dict = {}
|
||||
|
||||
|
||||
class ModelsData(BaseModel):
|
||||
"""
|
||||
Contains the information related to the models involved in a request.
|
||||
|
||||
- To load a model: set the relative path(s) to the model in `model_paths`. No effect if already loaded.
|
||||
- To unload a model: set the model to `None` in `model_paths`. No effect if already unloaded.
|
||||
|
||||
Models that aren't present in `model_paths` will not be changed.
|
||||
"""
|
||||
|
||||
model_paths: Dict[str, Union[str, None, List[str]]] = None
|
||||
"model_type to string path, or list of string paths"
|
||||
|
||||
model_params: Dict[str, Dict[str, Any]] = {}
|
||||
"model_type to dict of parameters"
|
||||
|
||||
|
||||
class OutputFormatData(BaseModel):
|
||||
output_format: str = "jpeg" # or "png" or "webp"
|
||||
output_quality: int = 75
|
||||
output_lossless: bool = False
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
request_id: str = None
|
||||
session_id: str = "session"
|
||||
@ -40,12 +72,13 @@ class TaskData(BaseModel):
|
||||
use_vae_model: Union[str, List[str]] = None
|
||||
use_hypernetwork_model: Union[str, List[str]] = None
|
||||
use_lora_model: Union[str, List[str]] = None
|
||||
use_controlnet_model: Union[str, List[str]] = None
|
||||
filters: List[str] = []
|
||||
filter_params: Dict[str, Dict[str, Any]] = {}
|
||||
control_filter_to_apply: Union[str, List[str]] = None
|
||||
|
||||
show_only_filtered_image: bool = False
|
||||
block_nsfw: bool = False
|
||||
output_format: str = "jpeg" # or "png" or "webp"
|
||||
output_quality: int = 75
|
||||
output_lossless: bool = False
|
||||
metadata_output_format: str = "txt" # or "json"
|
||||
stream_image_progress: bool = False
|
||||
stream_image_progress_interval: int = 5
|
||||
@ -80,24 +113,39 @@ class Image:
|
||||
}
|
||||
|
||||
|
||||
class Response:
|
||||
class GenerateImageResponse:
|
||||
render_request: GenerateImageRequest
|
||||
task_data: TaskData
|
||||
models_data: ModelsData
|
||||
images: list
|
||||
|
||||
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
|
||||
def __init__(
|
||||
self,
|
||||
render_request: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
output_format: OutputFormatData,
|
||||
images: list,
|
||||
):
|
||||
self.render_request = render_request
|
||||
self.task_data = task_data
|
||||
self.models_data = models_data
|
||||
self.output_format = output_format
|
||||
self.images = images
|
||||
|
||||
def json(self):
|
||||
del self.render_request.init_image
|
||||
del self.render_request.init_image_mask
|
||||
del self.render_request.control_image
|
||||
|
||||
task_data = self.task_data.dict()
|
||||
task_data.update(self.output_format.dict())
|
||||
|
||||
res = {
|
||||
"status": "succeeded",
|
||||
"render_request": self.render_request.dict(),
|
||||
"task_data": self.task_data.dict(),
|
||||
"task_data": task_data,
|
||||
# "models_data": self.models_data.dict(), # haven't migrated the UI to the new format (yet)
|
||||
"output": [],
|
||||
}
|
||||
|
||||
@ -107,5 +155,111 @@ class Response:
|
||||
return res
|
||||
|
||||
|
||||
class FilterImageResponse:
|
||||
request: FilterImageRequest
|
||||
models_data: ModelsData
|
||||
images: list
|
||||
|
||||
def __init__(self, request: FilterImageRequest, models_data: ModelsData, images: list):
|
||||
self.request = request
|
||||
self.models_data = models_data
|
||||
self.images = images
|
||||
|
||||
def json(self):
|
||||
del self.request.image
|
||||
|
||||
res = {
|
||||
"status": "succeeded",
|
||||
"request": self.request.dict(),
|
||||
"models_data": self.models_data.dict(),
|
||||
"output": [],
|
||||
}
|
||||
|
||||
for image in self.images:
|
||||
res["output"].append(image)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def convert_legacy_render_req_to_new(old_req: dict):
|
||||
new_req = dict(old_req)
|
||||
|
||||
# new keys
|
||||
model_paths = new_req["model_paths"] = {}
|
||||
model_params = new_req["model_params"] = {}
|
||||
filters = new_req["filters"] = []
|
||||
filter_params = new_req["filter_params"] = {}
|
||||
|
||||
# move the model info
|
||||
model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
|
||||
model_paths["vae"] = old_req.get("use_vae_model")
|
||||
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
|
||||
model_paths["lora"] = old_req.get("use_lora_model")
|
||||
model_paths["controlnet"] = old_req.get("use_controlnet_model")
|
||||
|
||||
model_paths["gfpgan"] = old_req.get("use_face_correction", "")
|
||||
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None
|
||||
|
||||
model_paths["codeformer"] = old_req.get("use_face_correction", "")
|
||||
model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None
|
||||
|
||||
model_paths["realesrgan"] = old_req.get("use_upscale", "")
|
||||
model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None
|
||||
|
||||
model_paths["latent_upscaler"] = old_req.get("use_upscale", "")
|
||||
model_paths["latent_upscaler"] = (
|
||||
model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
|
||||
)
|
||||
if "control_filter_to_apply" in old_req:
|
||||
filter_model = old_req["control_filter_to_apply"]
|
||||
model_paths[filter_model] = filter_model
|
||||
|
||||
if old_req.get("block_nsfw"):
|
||||
model_paths["nsfw_checker"] = "nsfw_checker"
|
||||
|
||||
# move the model params
|
||||
if model_paths["stable-diffusion"]:
|
||||
model_params["stable-diffusion"] = {
|
||||
"clip_skip": bool(old_req.get("clip_skip", False)),
|
||||
"convert_to_tensorrt": bool(old_req.get("convert_to_tensorrt", False)),
|
||||
"trt_build_config": old_req.get(
|
||||
"trt_build_config", {"batch_size_range": (1, 1), "dimensions_range": [(768, 1024)]}
|
||||
),
|
||||
}
|
||||
|
||||
# move the filter params
|
||||
if model_paths["realesrgan"]:
|
||||
filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))}
|
||||
if model_paths["latent_upscaler"]:
|
||||
filter_params["latent_upscaler"] = {
|
||||
"prompt": old_req["prompt"],
|
||||
"negative_prompt": old_req.get("negative_prompt"),
|
||||
"seed": int(old_req.get("seed", 42)),
|
||||
"num_inference_steps": int(old_req.get("latent_upscaler_steps", 10)),
|
||||
"guidance_scale": 0,
|
||||
}
|
||||
if model_paths["codeformer"]:
|
||||
filter_params["codeformer"] = {
|
||||
"upscale_faces": bool(old_req.get("codeformer_upscale_faces", True)),
|
||||
"codeformer_fidelity": float(old_req.get("codeformer_fidelity", 0.5)),
|
||||
}
|
||||
|
||||
# set the filters
|
||||
if old_req.get("block_nsfw"):
|
||||
filters.append("nsfw_checker")
|
||||
|
||||
if model_paths["codeformer"]:
|
||||
filters.append("codeformer")
|
||||
elif model_paths["gfpgan"]:
|
||||
filters.append("gfpgan")
|
||||
|
||||
if model_paths["realesrgan"]:
|
||||
filters.append("realesrgan")
|
||||
elif model_paths["latent_upscaler"]:
|
||||
filters.append("latent_upscaler")
|
||||
|
||||
return new_req
|
||||
|
@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from functools import reduce
|
||||
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
|
||||
from numpy import base_repr
|
||||
from sdkit.utils import save_dicts, save_images
|
||||
|
||||
@ -21,6 +21,8 @@ TASK_TEXT_MAPPING = {
|
||||
"seed": "Seed",
|
||||
"use_stable_diffusion_model": "Stable Diffusion model",
|
||||
"clip_skip": "Clip Skip",
|
||||
"use_controlnet_model": "ControlNet model",
|
||||
"control_filter_to_apply": "ControlNet Filter",
|
||||
"use_vae_model": "VAE model",
|
||||
"sampler_name": "Sampler",
|
||||
"width": "Width",
|
||||
@ -114,12 +116,14 @@ def format_file_name(
|
||||
return filename_regex.sub("_", format)
|
||||
|
||||
|
||||
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||
def save_images_to_disk(
|
||||
images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData
|
||||
):
|
||||
now = time.time()
|
||||
app_config = app.getConfig()
|
||||
folder_format = app_config.get("folder_format", "$id")
|
||||
save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
|
||||
metadata_entries = get_metadata_entries_for_request(req, task_data)
|
||||
metadata_entries = get_metadata_entries_for_request(req, task_data, output_format)
|
||||
file_number = calculate_img_number(save_dir_path, task_data)
|
||||
make_filename = make_filename_callback(
|
||||
app_config.get("filename_format", "$p_$tsb64"),
|
||||
@ -134,9 +138,9 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
filtered_images,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
output_lossless=task_data.output_lossless,
|
||||
output_format=output_format.output_format,
|
||||
output_quality=output_format.output_quality,
|
||||
output_lossless=output_format.output_lossless,
|
||||
)
|
||||
if task_data.metadata_output_format:
|
||||
for metadata_output_format in task_data.metadata_output_format.split(","):
|
||||
@ -146,7 +150,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=metadata_output_format,
|
||||
file_format=task_data.output_format,
|
||||
file_format=output_format.output_format,
|
||||
)
|
||||
else:
|
||||
make_filter_filename = make_filename_callback(
|
||||
@ -162,17 +166,17 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
images,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
output_lossless=task_data.output_lossless,
|
||||
output_format=output_format.output_format,
|
||||
output_quality=output_format.output_quality,
|
||||
output_lossless=output_format.output_lossless,
|
||||
)
|
||||
save_images(
|
||||
filtered_images,
|
||||
save_dir_path,
|
||||
file_name=make_filter_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
output_lossless=task_data.output_lossless,
|
||||
output_format=output_format.output_format,
|
||||
output_quality=output_format.output_quality,
|
||||
output_lossless=output_format.output_lossless,
|
||||
)
|
||||
if task_data.metadata_output_format:
|
||||
for metadata_output_format in task_data.metadata_output_format.split(","):
|
||||
@ -181,20 +185,21 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
metadata_entries,
|
||||
save_dir_path,
|
||||
file_name=make_filter_filename,
|
||||
output_format=task_data.metadata_output_format,
|
||||
file_format=task_data.output_format,
|
||||
output_format=metadata_output_format,
|
||||
file_format=output_format.output_format,
|
||||
)
|
||||
|
||||
|
||||
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
metadata = get_printable_request(req, task_data)
|
||||
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
|
||||
metadata = get_printable_request(req, task_data, output_format)
|
||||
|
||||
# if text, format it in the text format expected by the UI
|
||||
is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",")
|
||||
if is_txt_format:
|
||||
|
||||
def format_value(value):
|
||||
if isinstance(value, list):
|
||||
return ", ".join([ str(it) for it in value ])
|
||||
return ", ".join([str(it) for it in value])
|
||||
return value
|
||||
|
||||
metadata = {
|
||||
@ -208,9 +213,10 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD
|
||||
return entries
|
||||
|
||||
|
||||
def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
|
||||
req_metadata = req.dict()
|
||||
task_data_metadata = task_data.dict()
|
||||
task_data_metadata.update(output_format.dict())
|
||||
|
||||
app_config = app.getConfig()
|
||||
using_diffusers = app_config.get("test_diffusers", False)
|
||||
@ -224,6 +230,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
metadata[key] = task_data_metadata[key]
|
||||
elif key == "use_embedding_models" and using_diffusers:
|
||||
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
||||
|
||||
def scan_directory(directory_path: str):
|
||||
used_embeddings = []
|
||||
for entry in os.scandir(directory_path):
|
||||
@ -232,15 +239,18 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
if entry_extension not in embeddings_extensions:
|
||||
continue
|
||||
|
||||
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])")
|
||||
embedding_name_regex = regex.compile(
|
||||
r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
|
||||
)
|
||||
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
||||
used_embeddings.append(entry.path)
|
||||
elif entry.is_dir():
|
||||
used_embeddings.extend(scan_directory(entry.path))
|
||||
return used_embeddings
|
||||
|
||||
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
||||
metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
|
||||
|
||||
|
||||
# Clean up the metadata
|
||||
if req.init_image is None and "prompt_strength" in metadata:
|
||||
del metadata["prompt_strength"]
|
||||
@ -252,9 +262,13 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
del metadata["lora_alpha"]
|
||||
if task_data.use_upscale != "latent_upscaler" and "latent_upscaler_steps" in metadata:
|
||||
del metadata["latent_upscaler_steps"]
|
||||
if task_data.use_controlnet_model is None and "control_filter_to_apply" in metadata:
|
||||
del metadata["control_filter_to_apply"]
|
||||
|
||||
if not using_diffusers:
|
||||
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata):
|
||||
for key in (
|
||||
x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps", "use_controlnet_model", "control_filter_to_apply"] if x in metadata
|
||||
):
|
||||
del metadata[key]
|
||||
|
||||
return metadata
|
||||
|
113
ui/index.html
113
ui/index.html
@ -17,6 +17,7 @@
|
||||
<link rel="stylesheet" href="/media/css/searchable-models.css">
|
||||
<link rel="stylesheet" href="/media/css/image-modal.css">
|
||||
<link rel="stylesheet" href="/media/css/plugins.css">
|
||||
<link rel="stylesheet" href="/media/css/animations.css">
|
||||
<link rel="manifest" href="/media/manifest.webmanifest">
|
||||
<script src="/media/js/jquery-3.6.1.min.js"></script>
|
||||
<script src="/media/js/jquery-confirm.min.js"></script>
|
||||
@ -31,7 +32,7 @@
|
||||
<h1>
|
||||
<img id="logo_img" src="/media/images/icon-512x512.png" >
|
||||
Easy Diffusion
|
||||
<small><span id="version">v2.5.46</span> <span id="updateBranchLabel"></span></small>
|
||||
<small><span id="version">v2.5.48</span> <span id="updateBranchLabel"></span></small>
|
||||
</h1>
|
||||
</div>
|
||||
<div id="server-status">
|
||||
@ -82,8 +83,8 @@
|
||||
<label for="init_image">Initial Image (img2img) <small>(optional)</small> </label>
|
||||
|
||||
<div id="init_image_preview_container" class="image_preview_container">
|
||||
<div id="init_image_wrapper">
|
||||
<img id="init_image_preview" src="" crossorigin="anonymous" />
|
||||
<div id="init_image_wrapper" class="preview_image_wrapper">
|
||||
<img id="init_image_preview" class="image_preview" src="" crossorigin="anonymous" />
|
||||
<span id="init_image_size_box" class="img_bottom_label"></span>
|
||||
<button class="init_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
|
||||
</div>
|
||||
@ -108,6 +109,7 @@
|
||||
</div>
|
||||
|
||||
<div id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Preserve color profile <small>(helps during inpainting)</small></label></div>
|
||||
<div id="strict_mask_border_setting" class="pl-5"><input id="strict_mask_border" name="strict_mask_border" type="checkbox"> <label for="strict_mask_border">Strict Mask Border <small>(won't modify outside the mask, but the mask border might be visible)</small></label></div>
|
||||
|
||||
</div>
|
||||
|
||||
@ -139,12 +141,19 @@
|
||||
<div><table>
|
||||
<tr><b class="settings-subheader">Image Settings</b></tr>
|
||||
<tr class="pl-5"><td><label for="seed">Seed:</label></td><td><input id="seed" name="seed" size="10" value="0" onkeypress="preventNonNumericalInput(event)"> <input id="random_seed" name="random_seed" type="checkbox" checked><label for="random_seed">Random</label></td></tr>
|
||||
<tr class="pl-5"><td><label for="num_outputs_total">Number of Images:</label></td><td><input id="num_outputs_total" name="num_outputs_total" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label><small>(total)</small></label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label for="num_outputs_parallel"><small>(in parallel)</small></label></td></tr>
|
||||
<tr class="pl-5"><td><label for="num_outputs_total">Number of Images:</label></td><td><input id="num_outputs_total" name="num_outputs_total" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label><small>(total)</small></label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label id="num_outputs_parallel_label" for="num_outputs_parallel"><small>(in parallel)</small></label></td></tr>
|
||||
<tr class="pl-5"><td><label for="stable_diffusion_model">Model:</label></td><td class="model-input">
|
||||
<input id="stable_diffusion_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<button id="reload-models" class="secondaryButton reloadModels"><i class='fa-solid fa-rotate'></i></button>
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
|
||||
</td></tr>
|
||||
<tr class="pl-5 displayNone" id="enable_trt_config">
|
||||
<td><label for="convert_to_tensorrt">Enable TensorRT:</label></td>
|
||||
<td class="diffusers-restart-needed">
|
||||
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
|
||||
<!-- <label><small>Takes upto 20 mins the first time</small></label> -->
|
||||
</td>
|
||||
</tr>
|
||||
<tr class="pl-5 displayNone" id="clip_skip_config">
|
||||
<td><label for="clip_skip">Clip Skip:</label></td>
|
||||
<td class="diffusers-restart-needed">
|
||||
@ -152,6 +161,60 @@
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Clip-Skip" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about Clip Skip</span></i></a>
|
||||
</td>
|
||||
</tr>
|
||||
<tr id="controlnet_model_container" class="pl-5"><td><label for="controlnet_model">ControlNet Image:</label></td><td>
|
||||
<div id="control_image_wrapper" class="preview_image_wrapper">
|
||||
<img id="control_image_preview" class="image_preview" src="" crossorigin="anonymous" />
|
||||
<span id="control_image_size_box" class="img_bottom_label"></span>
|
||||
<button class="control_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
|
||||
</div>
|
||||
<input id="control_image" name="control_image" type="file" />
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/ControlNet" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about ControlNets</span></i></a>
|
||||
<div id="controlnet_config" class="displayNone">
|
||||
<label><small>Filter to apply:</small></label>
|
||||
<select id="control_image_filter">
|
||||
<option value="">None</option>
|
||||
<optgroup label="Pose">
|
||||
<option value="openpose">OpenPose (*)</option>
|
||||
<option value="openpose_face">OpenPose face</option>
|
||||
<option value="openpose_faceonly">OpenPose face-only</option>
|
||||
<option value="openpose_hand">OpenPose hand</option>
|
||||
<option value="openpose_full">OpenPose full</option>
|
||||
</optgroup>
|
||||
<optgroup label="Outline">
|
||||
<option value="canny">Canny (*)</option>
|
||||
<option value="mlsd">Straight lines</option>
|
||||
<option value="scribble_hed">Scribble hed (*)</option>
|
||||
<option value="scribble_hedsafe">Scribble hedsafe</option>
|
||||
<option value="scribble_pidinet">Scribble pidinet</option>
|
||||
<option value="scribble_pidsafe">Scribble pidsafe</option>
|
||||
<option value="softedge_hed">Softedge hed</option>
|
||||
<option value="softedge_hedsafe">Softedge hedsafe</option>
|
||||
<option value="softedge_pidinet">Softedge pidinet</option>
|
||||
<option value="softedge_pidsafe">Softedge pidsafe</option>
|
||||
</optgroup>
|
||||
<optgroup label="Depth">
|
||||
<option value="normal_bae">Normal bae (*)</option>
|
||||
<option value="depth_midas">Depth midas</option>
|
||||
<option value="depth_zoe">Depth zoe</option>
|
||||
<option value="depth_leres">Depth leres</option>
|
||||
<option value="depth_leres++">Depth leres++</option>
|
||||
</optgroup>
|
||||
<optgroup label="Line art">
|
||||
<option value="lineart_coarse">Lineart coarse</option>
|
||||
<option value="lineart_realistic">Lineart realistic</option>
|
||||
<option value="lineart_anime">Lineart anime</option>
|
||||
</optgroup>
|
||||
<optgroup label="Misc">
|
||||
<option value="shuffle">Shuffle</option>
|
||||
<option value="segment">Segment</option>
|
||||
</optgroup>
|
||||
</select>
|
||||
<br/>
|
||||
<label for="controlnet_model"><small>Model:</small></label> <input id="controlnet_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<br/>
|
||||
<label><small>Will download the necessary models, the first time.</small></label>
|
||||
</div>
|
||||
</td></tr>
|
||||
<tr class="pl-5"><td><label for="vae_model">Custom VAE:</label></td><td>
|
||||
<input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
|
||||
@ -182,15 +245,15 @@
|
||||
</select>
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
|
||||
</td></tr>
|
||||
<tr class="pl-5"><td><label>Image Size: </label></td><td>
|
||||
<tr class="pl-5"><td><label>Image Size: </label></td><td id="image-size-options">
|
||||
<select id="width" name="width" value="512">
|
||||
<option value="128">128 (*)</option>
|
||||
<option value="128">128</option>
|
||||
<option value="192">192</option>
|
||||
<option value="256">256 (*)</option>
|
||||
<option value="256">256</option>
|
||||
<option value="320">320</option>
|
||||
<option value="384">384</option>
|
||||
<option value="448">448</option>
|
||||
<option value="512" selected>512 (*)</option>
|
||||
<option value="512" selected="">512 (*)</option>
|
||||
<option value="576">576</option>
|
||||
<option value="640">640</option>
|
||||
<option value="704">704</option>
|
||||
@ -205,14 +268,15 @@
|
||||
<option value="2048">2048</option>
|
||||
</select>
|
||||
<label for="width"><small>(width)</small></label>
|
||||
<span id="swap-width-height" class="clickable smallButton" style="margin-left: 2px; margin-right:2px;"><i class="fa-solid fa-right-left"><span class="simple-tooltip top-left"> Swap width and height </span></i></span>
|
||||
<select id="height" name="height" value="512">
|
||||
<option value="128">128 (*)</option>
|
||||
<option value="128">128</option>
|
||||
<option value="192">192</option>
|
||||
<option value="256">256 (*)</option>
|
||||
<option value="256">256</option>
|
||||
<option value="320">320</option>
|
||||
<option value="384">384</option>
|
||||
<option value="448">448</option>
|
||||
<option value="512" selected>512 (*)</option>
|
||||
<option value="512" selected="">512 (*)</option>
|
||||
<option value="576">576</option>
|
||||
<option value="640">640</option>
|
||||
<option value="704">704</option>
|
||||
@ -227,6 +291,22 @@
|
||||
<option value="2048">2048</option>
|
||||
</select>
|
||||
<label for="height"><small>(height)</small></label>
|
||||
<div id="recent-resolutions-container">
|
||||
<span id="recent-resolutions-button" class="clickable"><i class="fa-solid fa-sliders"><span class="simple-tooltip top-left"> Advanced sizes </span></i></span>
|
||||
<div id="recent-resolutions-popup" class="displayNone">
|
||||
<small>Custom size:</small><br>
|
||||
<input id="custom-width" name="custom-width" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)">
|
||||
×
|
||||
<input id="custom-height" name="custom-height" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)"><br>
|
||||
|
||||
<small>Enlarge:</small><br>
|
||||
<div id="enlarge-buttons"><button id="enlarge15" class="tertiaryButton smallButton">×1.5</button> <button id="enlarge2" class="tertiaryButton smallButton">×2</button> <button id="enlarge3" class="tertiaryButton smallButton">×3</button></div>
|
||||
|
||||
<small>Recently used:</small><br>
|
||||
<div id="recent-resolution-list">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div id="small_image_warning" class="displayNone">Small image sizes can cause bad image quality</div>
|
||||
</td></tr>
|
||||
<tr class="pl-5"><td><label for="num_inference_steps">Inference Steps:</label></td><td> <input id="num_inference_steps" name="num_inference_steps" type="number" min="1" step="1" style="width: 42pt" value="25" onkeypress="preventNonNumericalInput(event)"></td></tr>
|
||||
@ -359,6 +439,13 @@
|
||||
<div class="parameters-table" id="system-settings-table"></div>
|
||||
<br/>
|
||||
<button id="save-system-settings-btn" class="primaryButton">Save</button>
|
||||
<div id="install-extras-container" class="displayNone">
|
||||
<br/>
|
||||
<div id="install-extras">
|
||||
<h3><i class="fa fa-cubes-stacked"></i> Optional Packages</h3>
|
||||
<div class="parameters-table" id="system-settings-install-extras-table"></div>
|
||||
</div>
|
||||
</div>
|
||||
<br/><br/>
|
||||
<div id="share-easy-diffusion">
|
||||
<h3><i class="fa fa-user-group"></i> Share Easy Diffusion</h3>
|
||||
@ -669,10 +756,10 @@ async function init() {
|
||||
events: {
|
||||
statusChange: setServerStatus,
|
||||
idle: onIdle,
|
||||
ping: tunnelUpdate
|
||||
ping: onPing
|
||||
}
|
||||
})
|
||||
splashScreen()
|
||||
// splashScreen()
|
||||
|
||||
// load models again, but scan for malicious this time
|
||||
await getModels(true)
|
||||
|
68
ui/media/css/animations.css
Normal file
68
ui/media/css/animations.css
Normal file
@ -0,0 +1,68 @@
|
||||
@keyframes ldio-8f673ktaleu-1 {
|
||||
0% { transform: rotate(0deg) }
|
||||
50% { transform: rotate(-45deg) }
|
||||
100% { transform: rotate(0deg) }
|
||||
}
|
||||
@keyframes ldio-8f673ktaleu-2 {
|
||||
0% { transform: rotate(180deg) }
|
||||
50% { transform: rotate(225deg) }
|
||||
100% { transform: rotate(180deg) }
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(2) {
|
||||
transform: translate(-15px,0);
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(2) div {
|
||||
position: absolute;
|
||||
top: 20px;
|
||||
left: 20px;
|
||||
width: 60px;
|
||||
height: 30px;
|
||||
border-radius: 60px 60px 0 0;
|
||||
background: #f3b72e;
|
||||
animation: ldio-8f673ktaleu-1 1s linear infinite;
|
||||
transform-origin: 30px 30px
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(2) div:nth-child(2) {
|
||||
animation: ldio-8f673ktaleu-2 1s linear infinite
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(2) div:nth-child(3) {
|
||||
transform: rotate(-90deg);
|
||||
animation: none;
|
||||
}@keyframes ldio-8f673ktaleu-3 {
|
||||
0% { transform: translate(95px,0); opacity: 0 }
|
||||
20% { opacity: 1 }
|
||||
100% { transform: translate(35px,0); opacity: 1 }
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(1) {
|
||||
display: block;
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(1) div {
|
||||
position: absolute;
|
||||
top: 46px;
|
||||
left: -4px;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: #3869c5;
|
||||
animation: ldio-8f673ktaleu-3 1s linear infinite
|
||||
}
|
||||
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(1) { animation-delay: -0.67s }
|
||||
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(2) { animation-delay: -0.33s }
|
||||
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(3) { animation-delay: 0s }
|
||||
.loadingio-spinner-bean-eater-x0y3u8qky4n {
|
||||
width: 58px;
|
||||
height: 58px;
|
||||
display: inline-block;
|
||||
overflow: hidden;
|
||||
background: none;
|
||||
}
|
||||
.ldio-8f673ktaleu {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
position: relative;
|
||||
transform: translateZ(0) scale(0.58);
|
||||
backface-visibility: hidden;
|
||||
transform-origin: 0 0; /* see note above */
|
||||
}
|
||||
.ldio-8f673ktaleu div { box-sizing: content-box; }
|
||||
/* generated by https://loading.io/ */
|
@ -794,7 +794,7 @@ div.img-preview img {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
#init_image_preview_container:not(.has-image) #init_image_wrapper,
|
||||
#init_image_preview_container:not(.has-image) .preview_image_wrapper,
|
||||
#init_image_preview_container:not(.has-image) #inpaint_button_container {
|
||||
display: none;
|
||||
}
|
||||
@ -831,14 +831,14 @@ div.img-preview img {
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
#init_image_wrapper {
|
||||
.preview_image_wrapper {
|
||||
grid-row: span 3;
|
||||
position: relative;
|
||||
width: fit-content;
|
||||
max-height: 150px;
|
||||
}
|
||||
|
||||
#init_image_preview {
|
||||
.image_preview {
|
||||
max-height: 150px;
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
@ -1753,3 +1753,82 @@ body.wait-pause {
|
||||
content: "Please restart Easy Diffusion!";
|
||||
font-size: 10pt;
|
||||
}
|
||||
|
||||
input#custom-width, input#custom-height {
|
||||
width: 47pt;
|
||||
}
|
||||
|
||||
div#recent-resolutions-container {
|
||||
position: relative;
|
||||
display:inline-block;
|
||||
}
|
||||
|
||||
div#recent-resolutions-popup {
|
||||
position: absolute;
|
||||
right: 0px;
|
||||
margin: 3px;
|
||||
padding: 0.2em 1em 0.4em 1em;
|
||||
z-index: 1;
|
||||
background: var(--background-color3);
|
||||
border-radius: 4px;
|
||||
box-shadow: 0 20px 28px 0 rgba(0, 0, 0, 0.15), 0 6px 20px 0 rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
div#recent-resolutions-popup small {
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
td#image-size-options small {
|
||||
margin-right: 0px !important;
|
||||
}
|
||||
|
||||
td#image-size-options {
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
div#recent-resolution-list {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
div#enlarge-buttons {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.clickable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.imgContainer .spinner {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
|
||||
background: var(--background-color3);
|
||||
opacity: 0.95;
|
||||
border-radius: 5px;
|
||||
padding: 4pt;
|
||||
border: 1px solid var(--button-color);
|
||||
box-shadow: 0px 0px 4px black;
|
||||
}
|
||||
|
||||
.imgContainer .spinnerStatus {
|
||||
font-size: 10pt;
|
||||
}
|
||||
|
||||
#controlnet_model_container small {
|
||||
color: var(--text-color)
|
||||
}
|
||||
#control_image {
|
||||
width: 130pt;
|
||||
}
|
||||
#controlnet_model {
|
||||
width: 77%;
|
||||
}
|
||||
|
||||
/* hack for fixing Image Modifier Improvements plugin */
|
||||
#imageTagPopupContainer {
|
||||
position: absolute;
|
||||
}
|
@ -1047,7 +1047,9 @@
|
||||
}
|
||||
}
|
||||
class FilterTask extends Task {
|
||||
constructor(options = {}) {}
|
||||
constructor(options = {}) {
|
||||
super(options)
|
||||
}
|
||||
/** Send current task to server.
|
||||
* @param {*} [timeout=-1] Optional timeout value in ms
|
||||
* @returns the response from the render request.
|
||||
@ -1055,9 +1057,27 @@
|
||||
*/
|
||||
async post(timeout = -1) {
|
||||
let jsonResponse = await super.post("/filter", timeout)
|
||||
//this._setId(jsonResponse.task)
|
||||
if (typeof jsonResponse?.task !== "number") {
|
||||
console.warn("Endpoint error response: ", jsonResponse)
|
||||
const event = Object.assign({ task: this }, jsonResponse)
|
||||
await eventSource.fireEvent(EVENT_UNEXPECTED_RESPONSE, event)
|
||||
if ("continueWith" in event) {
|
||||
jsonResponse = await Promise.resolve(event.continueWith)
|
||||
}
|
||||
if (typeof jsonResponse?.task !== "number") {
|
||||
const err = new Error(jsonResponse?.detail || "Endpoint response does not contains a task ID.")
|
||||
this.abort(err)
|
||||
throw err
|
||||
}
|
||||
}
|
||||
this._setId(jsonResponse.task)
|
||||
if (jsonResponse.stream) {
|
||||
this.streamUrl = jsonResponse.stream
|
||||
}
|
||||
this._setStatus(TaskStatus.waiting)
|
||||
return jsonResponse
|
||||
}
|
||||
checkReqBody() {}
|
||||
enqueue(progressCallback) {
|
||||
return Task.enqueueNew(this, FilterTask, progressCallback)
|
||||
}
|
||||
@ -1068,6 +1088,65 @@
|
||||
if (this.isStopped) {
|
||||
return
|
||||
}
|
||||
|
||||
this._setStatus(TaskStatus.pending)
|
||||
progressCallback?.call(this, { reqBody: this._reqBody })
|
||||
Object.freeze(this._reqBody)
|
||||
|
||||
// Post task request to backend
|
||||
let renderRes = undefined
|
||||
try {
|
||||
renderRes = yield this.post()
|
||||
yield progressCallback?.call(this, { renderResponse: renderRes })
|
||||
} catch (e) {
|
||||
yield progressCallback?.call(this, { detail: e.message })
|
||||
throw e
|
||||
}
|
||||
|
||||
try {
|
||||
// Wait for task to start on server.
|
||||
yield this.waitUntil({
|
||||
callback: function() {
|
||||
return progressCallback?.call(this, {})
|
||||
},
|
||||
status: TaskStatus.processing,
|
||||
})
|
||||
} catch (e) {
|
||||
this.abort(err)
|
||||
throw e
|
||||
}
|
||||
|
||||
// Task started!
|
||||
// Open the reader.
|
||||
const reader = this.reader
|
||||
const task = this
|
||||
reader.onError = function(response) {
|
||||
if (progressCallback) {
|
||||
task.abort(new Error(response.statusText))
|
||||
return progressCallback.call(task, { response, reader })
|
||||
}
|
||||
return Task.prototype.onError.call(task, response)
|
||||
}
|
||||
yield progressCallback?.call(this, { reader })
|
||||
|
||||
//Start streaming the results.
|
||||
const streamGenerator = reader.open()
|
||||
let value = undefined
|
||||
let done = undefined
|
||||
yield progressCallback?.call(this, { stream: streamGenerator })
|
||||
do {
|
||||
;({ value, done } = yield streamGenerator.next())
|
||||
if (typeof value !== "object") {
|
||||
continue
|
||||
}
|
||||
if (value.status !== undefined) {
|
||||
yield progressCallback?.call(this, value)
|
||||
if (value.status === "succeeded" || value.status === "failed") {
|
||||
done = true
|
||||
}
|
||||
}
|
||||
} while (!done)
|
||||
return value
|
||||
}
|
||||
static start(task, progressCallback) {
|
||||
if (typeof task !== "object") {
|
||||
|
@ -626,6 +626,7 @@ class ImageEditor {
|
||||
.getImageData(0, 0, this.width, this.height)
|
||||
.data.some((channel) => channel !== 0)
|
||||
maskSetting.checked = !is_blank
|
||||
maskSetting.dispatchEvent(new Event("change"))
|
||||
}
|
||||
this.hide()
|
||||
}
|
||||
|
@ -5,6 +5,9 @@ const MIN_GPUS_TO_SHOW_SELECTION = 2
|
||||
const IMAGE_REGEX = new RegExp("data:image/[A-Za-z]+;base64")
|
||||
const htmlTaskMap = new WeakMap()
|
||||
|
||||
const spinnerPacmanHtml =
|
||||
'<div class="loadingio-spinner-bean-eater-x0y3u8qky4n"><div class="ldio-8f673ktaleu"><div><div></div><div></div><div></div></div><div><div></div><div></div><div></div></div></div></div>'
|
||||
|
||||
const taskConfigSetup = {
|
||||
taskConfig: {
|
||||
seed: { value: ({ seed }) => seed, label: "Seed" },
|
||||
@ -46,6 +49,7 @@ const taskConfigSetup = {
|
||||
use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
|
||||
lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
|
||||
preserve_init_image_color_profile: "Preserve Color Profile",
|
||||
strict_mask_border: "Strict Mask Border",
|
||||
},
|
||||
pluginTaskConfig: {},
|
||||
getCSSKey: (key) =>
|
||||
@ -74,14 +78,30 @@ let randomSeedField = document.querySelector("#random_seed")
|
||||
let seedField = document.querySelector("#seed")
|
||||
let widthField = document.querySelector("#width")
|
||||
let heightField = document.querySelector("#height")
|
||||
let customWidthField = document.querySelector("#custom-width")
|
||||
let customHeightField = document.querySelector("#custom-height")
|
||||
let recentResolutionsButton = document.querySelector("#recent-resolutions-button")
|
||||
let recentResolutionsPopup = document.querySelector("#recent-resolutions-popup")
|
||||
let recentResolutionList = document.querySelector("#recent-resolution-list")
|
||||
let enlarge15Button = document.querySelector("#enlarge15")
|
||||
let enlarge2Button = document.querySelector("#enlarge2")
|
||||
let enlarge3Button = document.querySelector("#enlarge3")
|
||||
let swapWidthHeightButton = document.querySelector("#swap-width-height")
|
||||
let smallImageWarning = document.querySelector("#small_image_warning")
|
||||
let initImageSelector = document.querySelector("#init_image")
|
||||
let initImagePreview = document.querySelector("#init_image_preview")
|
||||
let initImageSizeBox = document.querySelector("#init_image_size_box")
|
||||
let maskImageSelector = document.querySelector("#mask")
|
||||
let maskImagePreview = document.querySelector("#mask_preview")
|
||||
let controlImageSelector = document.querySelector("#control_image")
|
||||
let controlImagePreview = document.querySelector("#control_image_preview")
|
||||
let controlImageClearBtn = document.querySelector(".control_image_clear")
|
||||
let controlImageContainer = document.querySelector("#control_image_wrapper")
|
||||
let controlImageFilterField = document.querySelector("#control_image_filter")
|
||||
let applyColorCorrectionField = document.querySelector("#apply_color_correction")
|
||||
let strictMaskBorderField = document.querySelector("#strict_mask_border")
|
||||
let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting")
|
||||
let strictMaskBorderSetting = document.querySelector("#strict_mask_border_setting")
|
||||
let promptStrengthSlider = document.querySelector("#prompt_strength_slider")
|
||||
let promptStrengthField = document.querySelector("#prompt_strength")
|
||||
let samplerField = document.querySelector("#sampler_name")
|
||||
@ -99,6 +119,7 @@ let codeformerFidelityField = document.querySelector("#codeformer_fidelity")
|
||||
let stableDiffusionModelField = new ModelDropdown(document.querySelector("#stable_diffusion_model"), "stable-diffusion")
|
||||
let clipSkipField = document.querySelector("#clip_skip")
|
||||
let tilingField = document.querySelector("#tiling")
|
||||
let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None", false)
|
||||
let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None")
|
||||
let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None")
|
||||
let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider")
|
||||
@ -160,6 +181,7 @@ let imagePreviewContent = document.querySelector("#preview-content")
|
||||
let undoButton = document.querySelector("#undo")
|
||||
let undoBuffer = []
|
||||
const UNDO_LIMIT = 20
|
||||
const MAX_IMG_UNDO_ENTRIES = 5
|
||||
|
||||
let loraModels = []
|
||||
|
||||
@ -271,24 +293,24 @@ function setServerStatus(event) {
|
||||
// e : MouseEvent
|
||||
// prompt : Text to be shown as prompt. Should be a question to which "yes" is a good answer.
|
||||
// fn : function to be called if the user confirms the dialog or has the shift key pressed
|
||||
// allowSkip: Allow skipping the dialog using the shift key or the confirm_dangerous_actions setting (default: true)
|
||||
//
|
||||
// If the user had the shift key pressed while clicking, the function fn will be executed.
|
||||
// If the setting "confirm_dangerous_actions" in the system settings is disabled, the function
|
||||
// fn will be executed.
|
||||
// Otherwise, a confirmation dialog is shown. If the user confirms, the function fn will also
|
||||
// be executed.
|
||||
function shiftOrConfirm(e, prompt, fn) {
|
||||
function shiftOrConfirm(e, prompt, fn, allowSkip = true) {
|
||||
e.stopPropagation()
|
||||
if (e.shiftKey || !confirmDangerousActionsField.checked) {
|
||||
let tip = allowSkip
|
||||
? '<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>'
|
||||
: ""
|
||||
if (allowSkip && (e.shiftKey || !confirmDangerousActionsField.checked)) {
|
||||
fn(e)
|
||||
} else {
|
||||
confirm(
|
||||
'<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>',
|
||||
prompt,
|
||||
() => {
|
||||
fn(e)
|
||||
}
|
||||
)
|
||||
confirm(tip, prompt, () => {
|
||||
fn(e)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -412,6 +434,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
</div>
|
||||
<button class="imgPreviewItemClearBtn image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
|
||||
<span class="img_bottom_label"></span>
|
||||
<div class="spinner displayNone"><center>${spinnerPacmanHtml}</center><div class="spinnerStatus"></div></div>
|
||||
</div>
|
||||
`
|
||||
outputContainer.appendChild(imageItemElem)
|
||||
@ -488,6 +511,8 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
const imageSeedLabel = imageItemElem.querySelector(".imgSeedLabel")
|
||||
imageSeedLabel.innerText = "Seed: " + req.seed
|
||||
|
||||
const imageUndoBuffer = []
|
||||
const imageRedoBuffer = []
|
||||
let buttons = [
|
||||
{ text: "Use as Input", on_click: onUseAsInputClick },
|
||||
[
|
||||
@ -505,8 +530,10 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
{ text: "Make Similar Images", on_click: onMakeSimilarClick },
|
||||
{ text: "Draw another 25 steps", on_click: onContinueDrawingClick },
|
||||
[
|
||||
{ text: "Upscale", on_click: onUpscaleClick, filter: (req, img) => !req.use_upscale },
|
||||
{ text: "Fix Faces", on_click: onFixFacesClick, filter: (req, img) => !req.use_face_correction },
|
||||
{ html: '<i class="fa-solid fa-undo"></i> Undo', on_click: onUndoFilter },
|
||||
{ html: '<i class="fa-solid fa-redo"></i> Redo', on_click: onRedoFilter },
|
||||
{ text: "Upscale", on_click: onUpscaleClick },
|
||||
{ text: "Fix Faces", on_click: onFixFacesClick },
|
||||
],
|
||||
]
|
||||
|
||||
@ -515,6 +542,14 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
|
||||
const imgItemInfo = imageItemElem.querySelector(".imgItemInfo")
|
||||
const img = imageItemElem.querySelector("img")
|
||||
const spinner = imageItemElem.querySelector(".spinner")
|
||||
const spinnerStatus = imageItemElem.querySelector(".spinnerStatus")
|
||||
const tools = {
|
||||
spinner: spinner,
|
||||
spinnerStatus: spinnerStatus,
|
||||
undoBuffer: imageUndoBuffer,
|
||||
redoBuffer: imageRedoBuffer,
|
||||
}
|
||||
const createButton = function(btnInfo) {
|
||||
if (Array.isArray(btnInfo)) {
|
||||
const wrapper = document.createElement("div")
|
||||
@ -540,8 +575,16 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
|
||||
if (btnInfo.on_click || !isLabel) {
|
||||
newButton.addEventListener("click", function(event) {
|
||||
btnInfo.on_click(req, img, event)
|
||||
btnInfo.on_click.bind(newButton)(req, img, event, tools)
|
||||
})
|
||||
if (btnInfo.on_click === onUndoFilter) {
|
||||
tools["undoButton"] = newButton
|
||||
newButton.classList.add("displayNone")
|
||||
}
|
||||
if (btnInfo.on_click === onRedoFilter) {
|
||||
tools["redoButton"] = newButton
|
||||
newButton.classList.add("displayNone")
|
||||
}
|
||||
}
|
||||
|
||||
if (btnInfo.class !== undefined) {
|
||||
@ -656,16 +699,86 @@ function enqueueImageVariationTask(req, img, reqDiff) {
|
||||
createTask(newTaskRequest)
|
||||
}
|
||||
|
||||
function onUpscaleClick(req, img) {
|
||||
enqueueImageVariationTask(req, img, {
|
||||
use_upscale: upscaleModelField.value,
|
||||
function applyInlineFilter(filterName, path, filterParams, img, statusText, tools) {
|
||||
const filterReq = {
|
||||
image: img.src,
|
||||
filter: filterName,
|
||||
model_paths: {},
|
||||
filter_params: filterParams,
|
||||
output_format: outputFormatField.value,
|
||||
output_quality: parseInt(outputQualityField.value),
|
||||
output_lossless: outputLosslessField.checked,
|
||||
}
|
||||
filterReq.model_paths[filterName] = path
|
||||
|
||||
tools.spinnerStatus.innerText = statusText
|
||||
tools.spinner.classList.remove("displayNone")
|
||||
|
||||
SD.filter(filterReq, (e) => {
|
||||
if (e.status === "succeeded") {
|
||||
let prevImg = img.src
|
||||
img.src = e.output[0]
|
||||
tools.spinner.classList.add("displayNone")
|
||||
|
||||
if (prevImg.length > 0) {
|
||||
tools.undoBuffer.push(prevImg)
|
||||
tools.redoBuffer = []
|
||||
|
||||
if (tools.undoBuffer.length > MAX_IMG_UNDO_ENTRIES) {
|
||||
let n = tools.undoBuffer.length
|
||||
tools.undoBuffer.splice(0, n - MAX_IMG_UNDO_ENTRIES)
|
||||
}
|
||||
|
||||
tools.undoButton.classList.remove("displayNone")
|
||||
tools.redoButton.classList.add("displayNone")
|
||||
}
|
||||
} else if (e.status == "failed") {
|
||||
alert("Error running upscale: " + e.detail)
|
||||
tools.spinner.classList.add("displayNone")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function onFixFacesClick(req, img) {
|
||||
enqueueImageVariationTask(req, img, {
|
||||
use_face_correction: gfpganModelField.value,
|
||||
})
|
||||
function moveImageBetweenBuffers(img, fromBuffer, toBuffer, fromButton, toButton) {
|
||||
if (fromBuffer.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
let src = fromBuffer.pop()
|
||||
if (src.length > 0) {
|
||||
toBuffer.push(img.src)
|
||||
img.src = src
|
||||
}
|
||||
|
||||
if (fromBuffer.length === 0) {
|
||||
fromButton.classList.add("displayNone")
|
||||
}
|
||||
if (toBuffer.length > 0) {
|
||||
toButton.classList.remove("displayNone")
|
||||
}
|
||||
}
|
||||
|
||||
function onUndoFilter(req, img, e, tools) {
|
||||
moveImageBetweenBuffers(img, tools.undoBuffer, tools.redoBuffer, tools.undoButton, tools.redoButton)
|
||||
}
|
||||
|
||||
function onRedoFilter(req, img, e, tools) {
|
||||
moveImageBetweenBuffers(img, tools.redoBuffer, tools.undoBuffer, tools.redoButton, tools.undoButton)
|
||||
}
|
||||
|
||||
function onUpscaleClick(req, img, e, tools) {
|
||||
let path = upscaleModelField.value
|
||||
let scale = parseInt(upscaleAmountField.value)
|
||||
let filterName = path.toLowerCase().includes("realesrgan") ? "realesrgan" : "latent_upscaler"
|
||||
let statusText = "Upscaling by " + scale + "x using " + filterName
|
||||
applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools)
|
||||
}
|
||||
|
||||
function onFixFacesClick(req, img, e, tools) {
|
||||
let path = gfpganModelField.value
|
||||
let filterName = path.toLowerCase().includes("gfpgan") ? "gfpgan" : "codeformer"
|
||||
let statusText = "Fixing faces with " + filterName
|
||||
applyInlineFilter(filterName, path, {}, img, statusText, tools)
|
||||
}
|
||||
|
||||
function onContinueDrawingClick(req, img) {
|
||||
@ -909,7 +1022,9 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
|
||||
<a href="https://www.ibm.com/docs/en/opw/8.2.0?topic=tuning-optional-increasing-paging-file-size-windows-computers" target="_blank">Windows</a> or
|
||||
<a href="https://linuxhint.com/increase-swap-space-linux/" target="_blank">Linux</a>.<br/>
|
||||
3. Try restarting your computer.<br/>`
|
||||
} else if (msg.includes("RuntimeError: output with shape [320, 320] doesn't match the broadcast shape")) {
|
||||
} else if (
|
||||
msg.includes("RuntimeError: output with shape [320, 320] doesn't match the broadcast shape")
|
||||
) {
|
||||
msg += `<br/><br/>
|
||||
<b>Reason</b>: You tried to use a LORA that was trained for a different Stable Diffusion model version!
|
||||
<br/><br/>
|
||||
@ -1237,9 +1352,25 @@ function createTask(task) {
|
||||
|
||||
function getCurrentUserRequest() {
|
||||
const numOutputsTotal = parseInt(numOutputsTotalField.value)
|
||||
const numOutputsParallel = parseInt(numOutputsParallelField.value)
|
||||
let numOutputsParallel = parseInt(numOutputsParallelField.value)
|
||||
const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value)
|
||||
|
||||
// if (
|
||||
// testDiffusers.checked &&
|
||||
// document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall" &&
|
||||
// document.querySelector("#convert_to_tensorrt").checked
|
||||
// ) {
|
||||
// // TRT enabled
|
||||
|
||||
// numOutputsParallel = 1 // force 1 parallel
|
||||
// }
|
||||
|
||||
// clamp to multiple of 8
|
||||
let width = parseInt(widthField.value)
|
||||
let height = parseInt(heightField.value)
|
||||
width = width - (width % 8)
|
||||
height = height - (height % 8)
|
||||
|
||||
const newTask = {
|
||||
batchesDone: 0,
|
||||
numOutputsTotal: numOutputsTotal,
|
||||
@ -1252,8 +1383,8 @@ function getCurrentUserRequest() {
|
||||
num_outputs: numOutputsParallel,
|
||||
num_inference_steps: parseInt(numInferenceStepsField.value),
|
||||
guidance_scale: parseFloat(guidanceScaleField.value),
|
||||
width: parseInt(widthField.value),
|
||||
height: parseInt(heightField.value),
|
||||
width: width,
|
||||
height: height,
|
||||
// allow_nsfw: allowNSFWField.checked,
|
||||
vram_usage_level: vramUsageLevelField.value,
|
||||
sampler_name: samplerField.value,
|
||||
@ -1283,6 +1414,7 @@ function getCurrentUserRequest() {
|
||||
// }
|
||||
if (maskSetting.checked) {
|
||||
newTask.reqBody.mask = imageInpainter.getImg()
|
||||
newTask.reqBody.strict_mask_border = strictMaskBorderField.checked
|
||||
}
|
||||
newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked
|
||||
if (!testDiffusers.checked) {
|
||||
@ -1323,6 +1455,34 @@ function getCurrentUserRequest() {
|
||||
newTask.reqBody.lora_alpha = modelStrengths
|
||||
}
|
||||
}
|
||||
if (testDiffusers.checked && document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
|
||||
// TRT is installed
|
||||
newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked
|
||||
let trtBuildConfig = {
|
||||
batch_size_range: [
|
||||
parseInt(document.querySelector("#trt-build-min-batch").value),
|
||||
parseInt(document.querySelector("#trt-build-max-batch").value),
|
||||
],
|
||||
dimensions_range: [],
|
||||
}
|
||||
|
||||
let sizes = [512, 768, 1024, 1280, 1536]
|
||||
sizes.forEach((i) => {
|
||||
let el = document.querySelector("#trt-build-res-" + i)
|
||||
if (el.checked) {
|
||||
trtBuildConfig["dimensions_range"].push([i, i + 256])
|
||||
}
|
||||
})
|
||||
newTask.reqBody.trt_build_config = trtBuildConfig
|
||||
}
|
||||
if (controlnetModelField.value !== "" && IMAGE_REGEX.test(controlImagePreview.src)) {
|
||||
newTask.reqBody.use_controlnet_model = controlnetModelField.value
|
||||
newTask.reqBody.control_image = controlImagePreview.src
|
||||
if (controlImageFilterField.value !== "") {
|
||||
newTask.reqBody.control_filter_to_apply = controlImageFilterField.value
|
||||
}
|
||||
}
|
||||
|
||||
return newTask
|
||||
}
|
||||
|
||||
@ -1728,6 +1888,51 @@ function onFixFaceModelChange() {
|
||||
gfpganModelField.addEventListener("change", onFixFaceModelChange)
|
||||
onFixFaceModelChange()
|
||||
|
||||
function onControlnetModelChange() {
|
||||
let configBox = document.querySelector("#controlnet_config")
|
||||
if (IMAGE_REGEX.test(controlImagePreview.src)) {
|
||||
configBox.classList.remove("displayNone")
|
||||
controlImageContainer.classList.remove("displayNone")
|
||||
} else {
|
||||
configBox.classList.add("displayNone")
|
||||
controlImageContainer.classList.add("displayNone")
|
||||
}
|
||||
}
|
||||
controlImagePreview.addEventListener("load", onControlnetModelChange)
|
||||
controlImagePreview.addEventListener("unload", onControlnetModelChange)
|
||||
onControlnetModelChange()
|
||||
|
||||
function onControlImageFilterChange() {
|
||||
let filterId = controlImageFilterField.value
|
||||
if (filterId.includes("openpose")) {
|
||||
controlnetModelField.value = "control_v11p_sd15_openpose"
|
||||
} else if (filterId === "canny") {
|
||||
controlnetModelField.value = "control_v11p_sd15_canny"
|
||||
} else if (filterId === "mlsd") {
|
||||
controlnetModelField.value = "control_v11p_sd15_mlsd"
|
||||
} else if (filterId === "mlsd") {
|
||||
controlnetModelField.value = "control_v11p_sd15_mlsd"
|
||||
} else if (filterId.includes("scribble")) {
|
||||
controlnetModelField.value = "control_v11p_sd15_scribble"
|
||||
} else if (filterId.includes("softedge")) {
|
||||
controlnetModelField.value = "control_v11p_sd15_softedge"
|
||||
} else if (filterId === "normal_bae") {
|
||||
controlnetModelField.value = "control_v11p_sd15_normalbae"
|
||||
} else if (filterId.includes("depth")) {
|
||||
controlnetModelField.value = "control_v11f1p_sd15_depth"
|
||||
} else if (filterId === "lineart_anime") {
|
||||
controlnetModelField.value = "control_v11p_sd15s2_lineart_anime"
|
||||
} else if (filterId.includes("lineart")) {
|
||||
controlnetModelField.value = "control_v11p_sd15_lineart"
|
||||
} else if (filterId === "shuffle") {
|
||||
controlnetModelField.value = "control_v11e_sd15_shuffle"
|
||||
} else if (filterId === "segment") {
|
||||
controlnetModelField.value = "control_v11p_sd15_seg"
|
||||
}
|
||||
}
|
||||
controlImageFilterField.addEventListener("change", onControlImageFilterChange)
|
||||
onControlImageFilterChange()
|
||||
|
||||
upscaleModelField.disabled = !useUpscalingField.checked
|
||||
upscaleAmountField.disabled = !useUpscalingField.checked
|
||||
useUpscalingField.addEventListener("change", function(e) {
|
||||
@ -1957,6 +2162,7 @@ function checkRandomSeed() {
|
||||
randomSeedField.addEventListener("input", checkRandomSeed)
|
||||
checkRandomSeed()
|
||||
|
||||
// warning: the core plugin `image-editor-improvements.js:172` replaces loadImg2ImgFromFile() with a custom version
|
||||
function loadImg2ImgFromFile() {
|
||||
if (initImageSelector.files.length === 0) {
|
||||
return
|
||||
@ -1983,6 +2189,7 @@ function img2imgLoad() {
|
||||
}
|
||||
initImagePreviewContainer.classList.add("has-image")
|
||||
colorCorrectionSetting.style.display = ""
|
||||
strictMaskBorderSetting.style.display = maskSetting.checked ? "" : "none"
|
||||
|
||||
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
|
||||
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
|
||||
@ -2000,6 +2207,7 @@ function img2imgUnload() {
|
||||
}
|
||||
initImagePreviewContainer.classList.remove("has-image")
|
||||
colorCorrectionSetting.style.display = "none"
|
||||
strictMaskBorderSetting.style.display = "none"
|
||||
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
|
||||
}
|
||||
initImagePreview.addEventListener("load", img2imgLoad)
|
||||
@ -2008,11 +2216,55 @@ initImageClearBtn.addEventListener("click", img2imgUnload)
|
||||
maskSetting.addEventListener("click", function() {
|
||||
onDimensionChange()
|
||||
})
|
||||
maskSetting.addEventListener("change", function() {
|
||||
strictMaskBorderSetting.style.display = this.checked ? "" : "none"
|
||||
})
|
||||
|
||||
promptsFromFileBtn.addEventListener("click", function() {
|
||||
promptsFromFileSelector.click()
|
||||
})
|
||||
|
||||
function loadControlnetImageFromFile() {
|
||||
if (controlImageSelector.files.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
let reader = new FileReader()
|
||||
let file = controlImageSelector.files[0]
|
||||
|
||||
reader.addEventListener("load", function(event) {
|
||||
controlImagePreview.src = reader.result
|
||||
})
|
||||
|
||||
if (file) {
|
||||
reader.readAsDataURL(file)
|
||||
}
|
||||
}
|
||||
controlImageSelector.addEventListener("change", loadControlnetImageFromFile)
|
||||
|
||||
function controlImageLoad() {
|
||||
let w = controlImagePreview.naturalWidth
|
||||
let h = controlImagePreview.naturalHeight
|
||||
w = w - (w % 8)
|
||||
h = h - (h % 8)
|
||||
|
||||
addImageSizeOption(w)
|
||||
addImageSizeOption(h)
|
||||
|
||||
widthField.value = w
|
||||
heightField.value = h
|
||||
widthField.dispatchEvent(new Event("change"))
|
||||
heightField.dispatchEvent(new Event("change"))
|
||||
}
|
||||
controlImagePreview.addEventListener("load", controlImageLoad)
|
||||
|
||||
function controlImageUnload() {
|
||||
controlImageSelector.value = null
|
||||
controlImagePreview.src = ""
|
||||
controlImagePreview.dispatchEvent(new Event("unload"))
|
||||
}
|
||||
controlImageClearBtn.addEventListener("click", controlImageUnload)
|
||||
|
||||
promptsFromFileSelector.addEventListener("change", async function() {
|
||||
if (promptsFromFileSelector.files.length === 0) {
|
||||
return
|
||||
@ -2123,6 +2375,11 @@ resumeBtn.addEventListener("click", function() {
|
||||
document.body.classList.remove("wait-pause")
|
||||
})
|
||||
|
||||
function onPing(event) {
|
||||
tunnelUpdate(event)
|
||||
packagesUpdate(event)
|
||||
}
|
||||
|
||||
function tunnelUpdate(event) {
|
||||
if ("cloudflare" in event) {
|
||||
document.getElementById("cloudflare-off").classList.add("displayNone")
|
||||
@ -2136,6 +2393,42 @@ function tunnelUpdate(event) {
|
||||
}
|
||||
}
|
||||
|
||||
let trtSettingsForced = false
|
||||
|
||||
function packagesUpdate(event) {
|
||||
let trtBtn = document.getElementById("toggle-tensorrt-install")
|
||||
let trtInstalled = "packages_installed" in event && "tensorrt" in event["packages_installed"]
|
||||
|
||||
if ("packages_installing" in event && event["packages_installing"].includes("tensorrt")) {
|
||||
trtBtn.innerHTML = "Installing.."
|
||||
trtBtn.disabled = true
|
||||
} else {
|
||||
trtBtn.innerHTML = trtInstalled ? "Uninstall" : "Install"
|
||||
trtBtn.disabled = false
|
||||
}
|
||||
|
||||
if (document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
|
||||
document.querySelector("#enable_trt_config").classList.remove("displayNone")
|
||||
document.querySelector("#trt-build-config").classList.remove("displayNone")
|
||||
|
||||
if (!trtSettingsForced) {
|
||||
// settings for demo
|
||||
promptField.value = "Dragons fighting with a knight, castle, war scene, fantasy, cartoon, flames, HD"
|
||||
seedField.value = 3187947173
|
||||
widthField.value = 1024
|
||||
heightField.value = 768
|
||||
randomSeedField.checked = false
|
||||
seedField.disabled = false
|
||||
stableDiffusionModelField.value = "sd-v1-4"
|
||||
|
||||
// numOutputsParallelField.classList.add("displayNone")
|
||||
// document.querySelector("#num_outputs_parallel_label").classList.add("displayNone")
|
||||
|
||||
trtSettingsForced = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", async function() {
|
||||
let command = "stop"
|
||||
if (document.getElementById("toggle-cloudflare-tunnel").innerHTML == "Start") {
|
||||
@ -2155,6 +2448,63 @@ document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", as
|
||||
console.log(`Cloudflare tunnel ${command} result:`, res)
|
||||
})
|
||||
|
||||
document.getElementById("toggle-tensorrt-install").addEventListener("click", function(e) {
|
||||
if (this.disabled === true) {
|
||||
return
|
||||
}
|
||||
|
||||
let command = this.innerHTML.toLowerCase()
|
||||
let self = this
|
||||
|
||||
shiftOrConfirm(
|
||||
e,
|
||||
"Are you sure you want to " + command + " TensorRT?",
|
||||
async function() {
|
||||
showToast(`TensorRT ${command} started. Please wait.`)
|
||||
|
||||
self.disabled = true
|
||||
|
||||
if (command === "install") {
|
||||
self.innerHTML = "Installing.."
|
||||
} else if (command === "uninstall") {
|
||||
self.innerHTML = "Uninstalling.."
|
||||
}
|
||||
|
||||
if (command === "installing..") {
|
||||
alert("Already installing TensorRT!")
|
||||
return
|
||||
}
|
||||
if (command !== "install" && command !== "uninstall") {
|
||||
return
|
||||
}
|
||||
|
||||
let res = await fetch("/package/tensorrt", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
command: command,
|
||||
}),
|
||||
})
|
||||
res = await res.json()
|
||||
|
||||
self.disabled = false
|
||||
|
||||
if (res.status === "OK") {
|
||||
alert("TensorRT " + command + "ed successfully!")
|
||||
self.innerHTML = command === "install" ? "Uninstall" : "Install"
|
||||
} else if (res.status_code === 500) {
|
||||
alert("TensorselfRT failed to " + command + ": " + res.detail)
|
||||
self.innerHTML = command === "install" ? "Install" : "Uninstall"
|
||||
}
|
||||
|
||||
console.log(`Package ${command} result:`, res)
|
||||
},
|
||||
false
|
||||
)
|
||||
})
|
||||
|
||||
/* Embeddings */
|
||||
|
||||
function updateEmbeddingsList(filter = "") {
|
||||
@ -2171,7 +2521,10 @@ function updateEmbeddingsList(filter = "") {
|
||||
} else {
|
||||
let subdir = html(m[1], prefix + m[0] + "/", filter)
|
||||
if (subdir != "") {
|
||||
folders += `<div class="embedding-category"><h4 class="collapsible">${prefix}${m[0]}</h4><div class="collapsible-content">` + subdir + '</div></div>'
|
||||
folders +=
|
||||
`<div class="embedding-category"><h4 class="collapsible">${prefix}${m[0]}</h4><div class="collapsible-content">` +
|
||||
subdir +
|
||||
"</div></div>"
|
||||
}
|
||||
}
|
||||
})
|
||||
@ -2293,7 +2646,6 @@ embeddingsCollapsiblesBtn.addEventListener("click", (e) => {
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
if (testDiffusers.checked) {
|
||||
document.getElementById("embeddings-container").classList.remove("displayNone")
|
||||
}
|
||||
@ -2390,3 +2742,172 @@ createLoraEntries()
|
||||
// }
|
||||
|
||||
// document.querySelectorAll("input[type=number]").forEach(showSpinnerOnlyOnHover)
|
||||
|
||||
////////////////////////////// Image Size Widget //////////////////////////////////////////
|
||||
|
||||
function roundToMultiple(number, n) {
|
||||
if (n == "") {
|
||||
n = 1
|
||||
}
|
||||
return Math.round(number / n) * n
|
||||
}
|
||||
|
||||
function addImageSizeOption(size) {
|
||||
let sizes = Object.values(widthField.options).map((o) => o.value)
|
||||
if (!sizes.includes(String(size))) {
|
||||
sizes.push(String(size))
|
||||
sizes.sort((a, b) => Number(a) - Number(b))
|
||||
|
||||
let option = document.createElement("option")
|
||||
option.value = size
|
||||
option.text = `${size}`
|
||||
|
||||
widthField.add(option, sizes.indexOf(String(size)))
|
||||
heightField.add(option.cloneNode(true), sizes.indexOf(String(size)))
|
||||
}
|
||||
}
|
||||
|
||||
function setImageWidthHeight(w, h) {
|
||||
let step = customWidthField.step
|
||||
w = roundToMultiple(w, step)
|
||||
h = roundToMultiple(h, step)
|
||||
|
||||
addImageSizeOption(w)
|
||||
addImageSizeOption(h)
|
||||
|
||||
widthField.value = w
|
||||
heightField.value = h
|
||||
widthField.dispatchEvent(new Event("change"))
|
||||
heightField.dispatchEvent(new Event("change"))
|
||||
}
|
||||
|
||||
function enlargeImageSize(factor) {
|
||||
let step = customWidthField.step
|
||||
|
||||
let w = roundToMultiple(widthField.value * factor, step)
|
||||
let h = roundToMultiple(heightField.value * factor, step)
|
||||
customWidthField.value = w
|
||||
customHeightField.value = h
|
||||
}
|
||||
|
||||
let recentResolutionsValues = []
|
||||
|
||||
;(function() {
|
||||
///// Init resolutions dropdown
|
||||
function makeResolutionButtons() {
|
||||
recentResolutionList.innerHTML = ""
|
||||
recentResolutionsValues.forEach((el) => {
|
||||
let button = document.createElement("button")
|
||||
button.classList.add("tertiaryButton")
|
||||
button.style.width = "8em"
|
||||
button.innerHTML = `${el.w}×${el.h}`
|
||||
button.addEventListener("click", () => {
|
||||
customWidthField.value = el.w
|
||||
customHeightField.value = el.h
|
||||
hidePopup()
|
||||
})
|
||||
recentResolutionList.appendChild(button)
|
||||
recentResolutionList.appendChild(document.createElement("br"))
|
||||
})
|
||||
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
|
||||
}
|
||||
|
||||
enlarge15Button.addEventListener("click", () => {
|
||||
enlargeImageSize(1.5)
|
||||
hidePopup()
|
||||
})
|
||||
|
||||
enlarge2Button.addEventListener("click", () => {
|
||||
enlargeImageSize(2)
|
||||
hidePopup()
|
||||
})
|
||||
|
||||
enlarge3Button.addEventListener("click", () => {
|
||||
enlargeImageSize(3)
|
||||
hidePopup()
|
||||
})
|
||||
|
||||
customWidthField.addEventListener("change", () => {
|
||||
let w = customWidthField.value
|
||||
customWidthField.value = roundToMultiple(w, customWidthField.step)
|
||||
if (w != customWidthField.value) {
|
||||
showToast(`Rounded width to the closest multiple of ${customWidthField.step}.`)
|
||||
}
|
||||
})
|
||||
|
||||
customHeightField.addEventListener("change", () => {
|
||||
let h = customHeightField.value
|
||||
customHeightField.value = roundToMultiple(h, customHeightField.step)
|
||||
if (h != customHeightField.value) {
|
||||
showToast(`Rounded height to the closest multiple of ${customHeightField.step}.`)
|
||||
}
|
||||
})
|
||||
|
||||
makeImageBtn.addEventListener("click", () => {
|
||||
let w = widthField.value
|
||||
let h = heightField.value
|
||||
|
||||
recentResolutionsValues = recentResolutionsValues.filter((el) => el.w != w || el.h != h)
|
||||
recentResolutionsValues.unshift({ w: w, h: h })
|
||||
recentResolutionsValues = recentResolutionsValues.slice(0, 8)
|
||||
|
||||
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
|
||||
makeResolutionButtons()
|
||||
})
|
||||
|
||||
let _jsonstring = localStorage.recentResolutionsValues
|
||||
if (_jsonstring == undefined) {
|
||||
recentResolutionsValues = [
|
||||
{ w: 512, h: 512 },
|
||||
{ w: 640, h: 448 },
|
||||
{ w: 448, h: 640 },
|
||||
{ w: 512, h: 768 },
|
||||
{ w: 768, h: 512 },
|
||||
{ w: 1024, h: 768 },
|
||||
{ w: 768, h: 1024 },
|
||||
]
|
||||
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
|
||||
} else {
|
||||
recentResolutionsValues = JSON.parse(localStorage.recentResolutionsValues)
|
||||
}
|
||||
makeResolutionButtons()
|
||||
|
||||
recentResolutionsValues.forEach((val) => {
|
||||
addImageSizeOption(val.w)
|
||||
addImageSizeOption(val.h)
|
||||
})
|
||||
|
||||
function processClick(e) {
|
||||
if (!recentResolutionsPopup.contains(e.target)) {
|
||||
hidePopup()
|
||||
}
|
||||
}
|
||||
|
||||
function showPopup() {
|
||||
customWidthField.value = widthField.value
|
||||
customHeightField.value = heightField.value
|
||||
recentResolutionsPopup.classList.remove("displayNone")
|
||||
document.addEventListener("click", processClick)
|
||||
}
|
||||
|
||||
function hidePopup() {
|
||||
recentResolutionsPopup.classList.add("displayNone")
|
||||
setImageWidthHeight(customWidthField.value, customHeightField.value)
|
||||
document.removeEventListener("click", processClick)
|
||||
}
|
||||
|
||||
recentResolutionsButton.addEventListener("click", (event) => {
|
||||
if (recentResolutionsPopup.classList.contains("displayNone")) {
|
||||
showPopup()
|
||||
event.stopPropagation()
|
||||
} else {
|
||||
hidePopup()
|
||||
}
|
||||
})
|
||||
|
||||
swapWidthHeightButton.addEventListener("click", (event) => {
|
||||
let temp = widthField.value
|
||||
widthField.value = heightField.value
|
||||
heightField.value = temp
|
||||
})
|
||||
})()
|
||||
|
@ -16,6 +16,7 @@ var ParameterType = {
|
||||
*/
|
||||
let parametersTable = document.querySelector("#system-settings-table")
|
||||
let networkParametersTable = document.querySelector("#system-settings-network-table")
|
||||
let installExtrasTable = document.querySelector("#system-settings-install-extras-table")
|
||||
|
||||
/**
|
||||
* JSDoc style
|
||||
@ -241,6 +242,29 @@ var PARAMETERS = [
|
||||
render: () => '<button id="toggle-cloudflare-tunnel" class="primaryButton">Start</button>',
|
||||
table: networkParametersTable,
|
||||
},
|
||||
{
|
||||
id: "nvidia_tensorrt",
|
||||
type: ParameterType.custom,
|
||||
label: "NVIDIA TensorRT",
|
||||
note: `Faster image generation by converting your Stable Diffusion models to the NVIDIA TensorRT format. You can choose the
|
||||
models to convert. Download size: approximately 2 GB.<br/><br/>
|
||||
<b>Early access version:</b> support for LoRA is still under development.
|
||||
<div id="trt-build-config" class="displayNone">
|
||||
<h3>Build Config:</h3>
|
||||
Batch size range:
|
||||
<label>Min:</label> <input id="trt-build-min-batch" type="number" min="1" value="1" style="width: 40pt" />
|
||||
<label>Max:</label> <input id="trt-build-max-batch" type="number" min="1" value="1" style="width: 40pt" /><br/><br/>
|
||||
<b>Build for resolutions</b>:<br/>
|
||||
<input id="trt-build-res-512" type="checkbox" value="1" /> 512x512 to 768x768<br/>
|
||||
<input id="trt-build-res-768" type="checkbox" value="1" checked /> 768x768 to 1024x1024<br/>
|
||||
<input id="trt-build-res-1024" type="checkbox" value="1" /> 1024x1024 to 1280x1280<br/>
|
||||
<input id="trt-build-res-1280" type="checkbox" value="1" /> 1280x1280 to 1536x1536<br/>
|
||||
<input id="trt-build-res-1536" type="checkbox" value="1" /> 1536x1536 to 1792x1792<br/>
|
||||
</div>`,
|
||||
icon: "fa-angles-up",
|
||||
render: () => '<button id="toggle-tensorrt-install" class="primaryButton">Install</button>',
|
||||
table: installExtrasTable,
|
||||
},
|
||||
]
|
||||
|
||||
function getParameterSettingsEntry(id) {
|
||||
@ -441,6 +465,8 @@ async function getAppConfig() {
|
||||
document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => {
|
||||
option.style.display = "none"
|
||||
})
|
||||
customWidthField.step = 64
|
||||
customHeightField.step = 64
|
||||
} else {
|
||||
document.querySelector("#lora_model_container").style.display = ""
|
||||
document.querySelector("#tiling_container").style.display = ""
|
||||
@ -451,6 +477,8 @@ async function getAppConfig() {
|
||||
document.querySelector("#clip_skip_config").classList.remove("displayNone")
|
||||
document.querySelector("#embeddings-button").classList.remove("displayNone")
|
||||
document.querySelector("#negative-embeddings-button").classList.remove("displayNone")
|
||||
customWidthField.step = 8
|
||||
customHeightField.step = 8
|
||||
}
|
||||
|
||||
console.log("get config status response", config)
|
||||
@ -582,6 +610,23 @@ function setDeviceInfo(devices) {
|
||||
systemInfoEl.querySelector("#system-info-cpu").innerText = cpu
|
||||
systemInfoEl.querySelector("#system-info-gpus-all").innerHTML = allGPUs.join("</br>")
|
||||
systemInfoEl.querySelector("#system-info-rendering-devices").innerHTML = activeGPUs.join("</br>")
|
||||
|
||||
// tensorRT
|
||||
if (devices.active && testDiffusers.checked && devices.enable_trt === true) {
|
||||
let nvidiaGPUs = Object.keys(devices.active).filter((d) => {
|
||||
let gpuName = devices.active[d].name
|
||||
gpuName = gpuName.toLowerCase()
|
||||
return (
|
||||
gpuName.includes("nvidia") ||
|
||||
gpuName.includes("geforce") ||
|
||||
gpuName.includes("quadro") ||
|
||||
gpuName.includes("tesla")
|
||||
)
|
||||
})
|
||||
if (nvidiaGPUs.length > 0) {
|
||||
document.querySelector("#install-extras-container").classList.remove("displayNone")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function setHostInfo(hosts) {
|
||||
@ -744,10 +789,10 @@ navigator.permissions.query({ name: "clipboard-write" }).then(function(result) {
|
||||
|
||||
document.addEventListener("system_info_update", (e) => setDeviceInfo(e.detail))
|
||||
|
||||
useBetaChannelField.addEventListener('change', (e) => {
|
||||
if (e.target.checked) {
|
||||
getParameterSettingsEntry("test_diffusers").classList.remove('displayNone')
|
||||
} else {
|
||||
getParameterSettingsEntry("test_diffusers").classList.add('displayNone')
|
||||
useBetaChannelField.addEventListener("change", (e) => {
|
||||
if (e.target.checked) {
|
||||
getParameterSettingsEntry("test_diffusers").classList.remove("displayNone")
|
||||
} else {
|
||||
getParameterSettingsEntry("test_diffusers").classList.add("displayNone")
|
||||
}
|
||||
})
|
||||
|
@ -552,17 +552,23 @@ class ModelDropdown {
|
||||
this.createModelNodeList(`${folderName || ""}/${childFolderName}`, childModels, false)
|
||||
)
|
||||
} else {
|
||||
let modelId = model
|
||||
let modelName = model
|
||||
if (typeof model === "object") {
|
||||
modelId = Object.keys(model)[0]
|
||||
modelName = model[modelId]
|
||||
}
|
||||
const classes = ["model-file"]
|
||||
if (isRootFolder) {
|
||||
classes.push("in-root-folder")
|
||||
}
|
||||
// Remove the leading slash from the model path
|
||||
const fullPath = folderName ? `${folderName.substring(1)}/${model}` : model
|
||||
const fullPath = folderName ? `${folderName.substring(1)}/${modelId}` : modelId
|
||||
modelsMap.set(
|
||||
model,
|
||||
modelId,
|
||||
createElement("li", { "data-path": fullPath }, classes, [
|
||||
createElement("i", undefined, ["fa-regular", "fa-file", "icon"]),
|
||||
model,
|
||||
modelName,
|
||||
])
|
||||
)
|
||||
}
|
||||
@ -643,22 +649,6 @@ async function getModels(scanForMalicious = true) {
|
||||
makeImageBtn.disabled = true
|
||||
}
|
||||
|
||||
/* This code should no longer be needed. Commenting out for now, will cleanup later.
|
||||
const sd_model_setting_key = "stable_diffusion_model"
|
||||
const vae_model_setting_key = "vae_model"
|
||||
const hypernetwork_model_key = "hypernetwork_model"
|
||||
|
||||
const stableDiffusionOptions = modelsOptions['stable-diffusion']
|
||||
const vaeOptions = modelsOptions['vae']
|
||||
const hypernetworkOptions = modelsOptions['hypernetwork']
|
||||
|
||||
// TODO: set default for model here too
|
||||
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]
|
||||
if (getSetting(sd_model_setting_key) == '' || SETTINGS[sd_model_setting_key].value == '') {
|
||||
setSetting(sd_model_setting_key, stableDiffusionOptions[0])
|
||||
}
|
||||
*/
|
||||
|
||||
// notify ModelDropdown objects to refresh
|
||||
document.dispatchEvent(new Event("refreshModels"))
|
||||
} catch (e) {
|
||||
@ -667,4 +657,7 @@ async function getModels(scanForMalicious = true) {
|
||||
}
|
||||
|
||||
// reload models button
|
||||
document.querySelector("#reload-models").addEventListener("click", () => getModels())
|
||||
document.querySelector("#reload-models").addEventListener("click", (e) => {
|
||||
e.stopPropagation()
|
||||
getModels()
|
||||
})
|
||||
|
@ -109,8 +109,10 @@
|
||||
|
||||
imageObj.onload = function() {
|
||||
// Calculate the maximum cropped dimensions
|
||||
const maxCroppedWidth = Math.floor(this.width / 64) * 64;
|
||||
const maxCroppedHeight = Math.floor(this.height / 64) * 64;
|
||||
const step = customWidthField.step
|
||||
|
||||
const maxCroppedWidth = Math.floor(this.width / step) * step;
|
||||
const maxCroppedHeight = Math.floor(this.height / step) * step;
|
||||
|
||||
canvas.width = maxCroppedWidth;
|
||||
canvas.height = maxCroppedHeight;
|
||||
@ -122,35 +124,17 @@
|
||||
// Draw the image with centered coordinates
|
||||
context.drawImage(imageObj, x, y, this.width, this.height);
|
||||
|
||||
initImagePreview.src = canvas.toDataURL('image/png');
|
||||
let bestWidth = maxCroppedWidth - maxCroppedWidth % 8
|
||||
let bestHeight = maxCroppedHeight - maxCroppedHeight % 8
|
||||
|
||||
// Get the options from widthField and heightField
|
||||
const widthOptions = Array.from(widthField.options).map(option => parseInt(option.value));
|
||||
const heightOptions = Array.from(heightField.options).map(option => parseInt(option.value));
|
||||
|
||||
// Find the closest aspect ratio and closest to original dimensions
|
||||
let bestWidth = widthOptions[0];
|
||||
let bestHeight = heightOptions[0];
|
||||
let minDifference = Math.abs(maxCroppedWidth / maxCroppedHeight - bestWidth / bestHeight);
|
||||
let minDistance = Math.abs(maxCroppedWidth - bestWidth) + Math.abs(maxCroppedHeight - bestHeight);
|
||||
|
||||
for (const width of widthOptions) {
|
||||
for (const height of heightOptions) {
|
||||
const difference = Math.abs(maxCroppedWidth / maxCroppedHeight - width / height);
|
||||
const distance = Math.abs(maxCroppedWidth - width) + Math.abs(maxCroppedHeight - height);
|
||||
|
||||
if (difference < minDifference || (difference === minDifference && distance < minDistance)) {
|
||||
minDifference = difference;
|
||||
minDistance = distance;
|
||||
bestWidth = width;
|
||||
bestHeight = height;
|
||||
}
|
||||
}
|
||||
}
|
||||
addImageSizeOption(bestWidth)
|
||||
addImageSizeOption(bestHeight)
|
||||
|
||||
// Set the width and height to the closest aspect ratio and closest to original dimensions
|
||||
widthField.value = bestWidth;
|
||||
heightField.value = bestHeight;
|
||||
|
||||
initImagePreview.src = canvas.toDataURL('image/png');
|
||||
};
|
||||
|
||||
function handlePaste(e) {
|
||||
|
114
ui/plugins/ui/lora-prompt-parser.plugin.js
Normal file
114
ui/plugins/ui/lora-prompt-parser.plugin.js
Normal file
@ -0,0 +1,114 @@
|
||||
/*
|
||||
LoRA Prompt Parser 1.0
|
||||
by Patrice
|
||||
|
||||
Copying and pasting a prompt with a LoRA tag will automatically select the corresponding option in the Easy Diffusion dropdown and remove the LoRA tag from the prompt. The LoRA must be already available in the corresponding Easy Diffusion dropdown (this is not a LoRA downloader).
|
||||
*/
|
||||
(function() {
|
||||
"use strict"
|
||||
|
||||
promptField.addEventListener('input', function(e) {
|
||||
const { LoRA, prompt } = extractLoraTags(e.target.value);
|
||||
//console.log('e.target: ' + JSON.stringify(LoRA));
|
||||
|
||||
if (LoRA !== null && LoRA.length > 0) {
|
||||
promptField.value = prompt.replace(/,+$/, ''); // remove any trailing ,
|
||||
|
||||
if (testDiffusers?.checked === false) {
|
||||
showToast("LoRA's are only supported with diffusers. Just stripping the LoRA tag from the prompt.")
|
||||
}
|
||||
}
|
||||
|
||||
if (LoRA !== null && LoRA.length > 0 && testDiffusers?.checked) {
|
||||
for (let i = 0; i < LoRA.length; i++) {
|
||||
//if (loraModelField.value !== LoRA[0].lora_model) {
|
||||
// Set the new LoRA value
|
||||
//console.log("Loading info");
|
||||
//console.log(LoRA[0].lora_model_0);
|
||||
//console.log(JSON.stringify(LoRa));
|
||||
|
||||
let lora = `lora_model_${i}`;
|
||||
let alpha = `lora_alpha_${i}`;
|
||||
let loramodel = document.getElementById(lora);
|
||||
let alphavalue = document.getElementById(alpha);
|
||||
loramodel.setAttribute("data-path", LoRA[i].lora_model_0);
|
||||
loramodel.value = LoRA[i].lora_model_0;
|
||||
alphavalue.value = LoRA[i].lora_alpha_0;
|
||||
if (i != LoRA.length - 1)
|
||||
createLoraEntry();
|
||||
}
|
||||
//loraAlphaSlider.value = loraAlphaField.value * 100;
|
||||
//TBD.value = LoRA[0].blockweights; // block weights not supported by ED at this time
|
||||
//}
|
||||
showToast("Prompt successfully processed", LoRA[0].lora_model_0);
|
||||
//console.log('LoRa: ' + LoRA[0].lora_model_0);
|
||||
//showToast("Prompt successfully processed", lora_model_0.value);
|
||||
|
||||
}
|
||||
|
||||
//promptField.dispatchEvent(new Event('change'));
|
||||
});
|
||||
|
||||
function isModelAvailable(array, searchString) {
|
||||
const foundItem = array.find(function(item) {
|
||||
item = item.toString().toLowerCase();
|
||||
return item === searchString.toLowerCase()
|
||||
});
|
||||
|
||||
return foundItem || "";
|
||||
}
|
||||
|
||||
// extract LoRA tags from strings
|
||||
function extractLoraTags(prompt) {
|
||||
// Define the regular expression for the tags
|
||||
const regex = /<(?:lora|lyco):([^:>]+)(?::([^:>]*))?(?::([^:>]*))?>/gi
|
||||
|
||||
// Initialize an array to hold the matches
|
||||
let matches = []
|
||||
|
||||
// Iterate over the string, finding matches
|
||||
for (const match of prompt.matchAll(regex)) {
|
||||
const modelFileName = isModelAvailable(modelsCache.options.lora, match[1].trim())
|
||||
if (modelFileName !== "") {
|
||||
// Initialize an object to hold a match
|
||||
let loraTag = {
|
||||
lora_model_0: modelFileName,
|
||||
}
|
||||
//console.log("Model:" + modelFileName);
|
||||
|
||||
// If weight is provided, add it to the loraTag object
|
||||
if (match[2] !== undefined && match[2] !== '') {
|
||||
loraTag.lora_alpha_0 = parseFloat(match[2].trim())
|
||||
}
|
||||
else
|
||||
{
|
||||
loraTag.lora_alpha_0 = 0.5
|
||||
}
|
||||
|
||||
|
||||
// If blockweights are provided, add them to the loraTag object
|
||||
if (match[3] !== undefined && match[3] !== '') {
|
||||
loraTag.blockweights = match[3].trim()
|
||||
}
|
||||
|
||||
// Add the loraTag object to the array of matches
|
||||
matches.push(loraTag);
|
||||
//console.log(JSON.stringify(matches));
|
||||
}
|
||||
else
|
||||
{
|
||||
showToast("LoRA not found: " + match[1].trim(), 5000, true)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up the prompt string, e.g. from "apple, banana, <lora:...>, orange, <lora:...> , pear <lora:...>, <lora:...>" to "apple, banana, orange, pear"
|
||||
let cleanedPrompt = prompt.replace(regex, '').replace(/(\s*,\s*(?=\s*,|$))|(^\s*,\s*)|\s+/g, ' ').trim();
|
||||
//console.log('Matches: ' + JSON.stringify(matches));
|
||||
|
||||
// Return the array of matches and cleaned prompt string
|
||||
return {
|
||||
LoRA: matches,
|
||||
prompt: cleanedPrompt
|
||||
}
|
||||
}
|
||||
})()
|
Reference in New Issue
Block a user