Merge branch 'beta' into bucketlite

This commit is contained in:
JeLuF 2023-07-30 23:41:39 +02:00
commit 518df4bd3e
23 changed files with 1506 additions and 372 deletions

View File

@ -22,6 +22,14 @@
Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed. Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed.
### Detailed changelog ### Detailed changelog
* 2.5.47 - 30 Jul 2023 - An option to use `Strict Mask Border` while inpainting, to avoid touching areas outside the mask. But this might show a slight outline of the mask, which you will have to touch up separately.
* 2.5.47 - 29 Jul 2023 - (beta-only) Fix long prompts with SDXL.
* 2.5.47 - 29 Jul 2023 - (beta-only) Fix red dots in some SDXL images.
* 2.5.47 - 29 Jul 2023 - Significantly faster `Fix Faces` and `Upscale` buttons (on the image). They no longer need to generate the image from scratch, instead they just upscale/fix the generated image in-place.
* 2.5.47 - 28 Jul 2023 - Lots of internal code reorganization, in preparation for supporting Controlnets. No user-facing changes.
* 2.5.46 - 27 Jul 2023 - (beta-only) Full support for SD-XL models (base and refiner)!
* 2.5.45 - 24 Jul 2023 - (beta-only) Hide the samplers that won't be supported in the new diffusers version.
* 2.5.45 - 22 Jul 2023 - (beta-only) Fix the recently-broken inpainting models.
* 2.5.45 - 16 Jul 2023 - (beta-only) Fix the image quality of LoRAs, which had degraded in v2.5.44. * 2.5.45 - 16 Jul 2023 - (beta-only) Fix the image quality of LoRAs, which had degraded in v2.5.44.
* 2.5.44 - 15 Jul 2023 - (beta-only) Support for multiple LoRA files. * 2.5.44 - 15 Jul 2023 - (beta-only) Support for multiple LoRA files.
* 2.5.43 - 9 Jul 2023 - (beta-only) Support for loading Textual Inversion embeddings. You can find the option in the Image Settings panel. Thanks @JeLuf. * 2.5.43 - 9 Jul 2023 - (beta-only) Support for loading Textual Inversion embeddings. You can find the option in the Image Settings panel. Thanks @JeLuf.

View File

@ -18,7 +18,7 @@ os_name = platform.system()
modules_to_check = { modules_to_check = {
"torch": ("1.11.0", "1.13.1", "2.0.0"), "torch": ("1.11.0", "1.13.1", "2.0.0"),
"torchvision": ("0.12.0", "0.14.1", "0.15.1"), "torchvision": ("0.12.0", "0.14.1", "0.15.1"),
"sdkit": "1.0.134", "sdkit": "1.0.151",
"stable-diffusion-sdkit": "2.1.4", "stable-diffusion-sdkit": "2.1.4",
"rich": "12.6.0", "rich": "12.6.0",
"uvicorn": "0.19.0", "uvicorn": "0.19.0",

View File

@ -32,6 +32,8 @@ logging.basicConfig(
SD_DIR = os.getcwd() SD_DIR = os.getcwd()
ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, ".."))
SD_UI_DIR = os.getenv("SD_UI_PATH", None) SD_UI_DIR = os.getenv("SD_UI_PATH", None)
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
@ -103,6 +105,7 @@ def init_render_threads():
update_render_threads() update_render_threads()
def getConfig(default_val=APP_CONFIG_DEFAULTS): def getConfig(default_val=APP_CONFIG_DEFAULTS):
config_yaml_path = os.path.join(CONFIG_DIR, "..", "config.yaml") config_yaml_path = os.path.join(CONFIG_DIR, "..", "config.yaml")
@ -112,9 +115,9 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
shutil.move(config_legacy_yaml, config_yaml_path) shutil.move(config_legacy_yaml, config_yaml_path)
def set_config_on_startup(config: dict): def set_config_on_startup(config: dict):
if (getConfig.__config_on_startup is None): if getConfig.__test_diffusers_on_startup is None:
getConfig.__config_on_startup = copy.deepcopy(config) getConfig.__test_diffusers_on_startup = config.get("test_diffusers", False)
config["config_on_startup"] = getConfig.__config_on_startup config["config_on_startup"] = {"test_diffusers": getConfig.__test_diffusers_on_startup}
if os.path.isfile(config_yaml_path): if os.path.isfile(config_yaml_path):
try: try:
@ -161,7 +164,8 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
set_config_on_startup(default_val) set_config_on_startup(default_val)
return default_val return default_val
getConfig.__config_on_startup = None
getConfig.__test_diffusers_on_startup = None
def setConfig(config): def setConfig(config):
@ -182,6 +186,9 @@ def setConfig(config):
config = commented_config config = commented_config
yaml.indent(mapping=2, sequence=4, offset=2) yaml.indent(mapping=2, sequence=4, offset=2)
if "config_on_startup" in config:
del config["config_on_startup"]
try: try:
f = open(config_yaml_path + ".tmp", "w", encoding="utf-8") f = open(config_yaml_path + ".tmp", "w", encoding="utf-8")
yaml.dump(config, f) yaml.dump(config, f)

View File

@ -5,7 +5,7 @@ import traceback
from typing import Union from typing import Union
from easydiffusion import app from easydiffusion import app
from easydiffusion.types import TaskData from easydiffusion.types import ModelsData
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit import Context from sdkit import Context
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
@ -57,7 +57,9 @@ def init():
def load_default_models(context: Context): def load_default_models(context: Context):
set_vram_optimizations(context) from easydiffusion import runtime
runtime.set_vram_optimizations(context)
config = app.getConfig() config = app.getConfig()
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings") context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
@ -138,43 +140,32 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?") raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
def reload_models_if_necessary(context: Context, task_data: TaskData): def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else ""
upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else ""
model_paths_in_req = {
"stable-diffusion": task_data.use_stable_diffusion_model,
"vae": task_data.use_vae_model,
"hypernetwork": task_data.use_hypernetwork_model,
"codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None,
"gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None,
"realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None,
"latent_upscaler": True if "latent_upscaler" in upscale_lower else None,
"nsfw_checker": True if task_data.block_nsfw else None,
"lora": task_data.use_lora_model,
}
models_to_reload = { models_to_reload = {
model_type: path model_type: path
for model_type, path in model_paths_in_req.items() for model_type, path in models_data.model_paths.items()
if context.model_paths.get(model_type) != path if context.model_paths.get(model_type) != path
} }
if task_data.codeformer_upscale_faces: if models_data.model_paths.get("codeformer"):
if "realesrgan" not in models_to_reload and "realesrgan" not in context.models: if "realesrgan" not in models_to_reload and "realesrgan" not in context.models:
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None: elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None:
del models_to_reload["realesrgan"] # don't unload realesrgan del models_to_reload["realesrgan"] # don't unload realesrgan
if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD for model_type in models_to_force_reload:
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"] if model_type not in models_data.model_paths:
continue
models_to_reload[model_type] = models_data.model_paths[model_type]
for model_type, model_path_in_req in models_to_reload.items(): for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req context.model_paths[model_type] = model_path_in_req
action_fn = unload_model if context.model_paths[model_type] is None else load_model action_fn = unload_model if context.model_paths[model_type] is None else load_model
extra_params = models_data.model_params.get(model_type, {})
try: try:
action_fn(context, model_type, scan_model=False) # we've scanned them already action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already
if model_type in context.model_load_errors: if model_type in context.model_load_errors:
del context.model_load_errors[model_type] del context.model_load_errors[model_type]
except Exception as e: except Exception as e:
@ -183,24 +174,15 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks
def resolve_model_paths(task_data: TaskData): def resolve_model_paths(models_data: ModelsData):
task_data.use_stable_diffusion_model = resolve_model_to_use( model_paths = models_data.model_paths
task_data.use_stable_diffusion_model, model_type="stable-diffusion" for model_type in model_paths:
) if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae") continue
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork") if model_type == "codeformer":
task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora")
if task_data.use_face_correction:
if "gfpgan" in task_data.use_face_correction.lower():
model_type = "gfpgan"
elif "codeformer" in task_data.use_face_correction.lower():
model_type = "codeformer"
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type) model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
def fail_if_models_did_not_load(context: Context): def fail_if_models_did_not_load(context: Context):
@ -235,17 +217,6 @@ def download_if_necessary(model_type: str, file_name: str, model_id: str):
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR) download_model(model_type, model_id, download_base_dir=app.MODELS_DIR)
def set_vram_optimizations(context: Context):
config = app.getConfig()
vram_usage_level = config.get("vram_usage_level", "balanced")
if vram_usage_level != context.vram_usage_level:
context.vram_usage_level = vram_usage_level
return True
return False
def migrate_legacy_model_location(): def migrate_legacy_model_location():
'Move the models inside the legacy "stable-diffusion" folder, to their respective folders' 'Move the models inside the legacy "stable-diffusion" folder, to their respective folders'
@ -266,16 +237,6 @@ def any_model_exists(model_type: str) -> bool:
return False return False
def set_clip_skip(context: Context, task_data: TaskData):
clip_skip = task_data.clip_skip
if clip_skip != context.clip_skip:
context.clip_skip = clip_skip
return True
return False
def make_model_folders(): def make_model_folders():
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type) model_dir_path = os.path.join(app.MODELS_DIR, model_type)

View File

@ -0,0 +1,98 @@
import sys
import os
import platform
from importlib.metadata import version as pkg_version
from sdkit.utils import log
from easydiffusion import app
# future home of scripts/check_modules.py
manifest = {
"tensorrt": {
"install": [
"nvidia-cudnn --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
"tensorrt-libs --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
"tensorrt --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com",
],
"uninstall": ["tensorrt"],
# TODO also uninstall tensorrt-libs and nvidia-cudnn, but do it upon restarting (avoid 'file in use' error)
}
}
installing = []
# remove this once TRT releases on pypi
if platform.system() == "Windows":
trt_dir = os.path.join(app.ROOT_DIR, "tensorrt")
if os.path.exists(trt_dir):
files = os.listdir(trt_dir)
packages = manifest["tensorrt"]["install"]
packages = tuple(p.replace("-", "_") for p in packages)
wheels = []
for p in packages:
p = p.split(" ")[0]
f = next((f for f in files if f.startswith(p) and f.endswith((".whl", ".tar.gz"))), None)
if f:
wheels.append(os.path.join(trt_dir, f))
manifest["tensorrt"]["install"] = wheels
def get_installed_packages() -> list:
return {module_name: version(module_name) for module_name in manifest if is_installed(module_name)}
def is_installed(module_name) -> bool:
return version(module_name) is not None
def install(module_name):
if is_installed(module_name):
log.info(f"{module_name} has already been installed!")
return
if module_name in installing:
log.info(f"{module_name} is already installing!")
return
if module_name not in manifest:
raise RuntimeError(f"Can't install unknown package: {module_name}!")
commands = manifest[module_name]["install"]
commands = [f"python -m pip install --upgrade {cmd}" for cmd in commands]
installing.append(module_name)
try:
for cmd in commands:
print(">", cmd)
if os.system(cmd) != 0:
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
finally:
installing.remove(module_name)
def uninstall(module_name):
if not is_installed(module_name):
log.info(f"{module_name} hasn't been installed!")
return
if module_name not in manifest:
raise RuntimeError(f"Can't uninstall unknown package: {module_name}!")
commands = manifest[module_name]["uninstall"]
commands = [f"python -m pip uninstall -y {cmd}" for cmd in commands]
for cmd in commands:
print(">", cmd)
if os.system(cmd) != 0:
raise RuntimeError(f"Error while running {cmd}. Please check the logs in the command-line.")
def version(module_name: str) -> str:
try:
return pkg_version(module_name)
except:
return None

View File

@ -0,0 +1,53 @@
"""
A runtime that runs on a specific device (in a thread).
It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context.
This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function.
"""
from easydiffusion import device_manager
from easydiffusion.utils import log
from sdkit import Context
from sdkit.utils import get_device_usage
context = Context() # thread-local
"""
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
"""
def init(device):
"""
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
"""
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
context.model_load_errors = {}
context.enable_codeformer = True
from easydiffusion import app
app_config = app.getConfig()
context.test_diffusers = (
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
)
log.info("Device usage during initialization:")
get_device_usage(device, log_info=True, process_usage_only=False)
device_manager.device_init(context, device)
def set_vram_optimizations(context: Context):
from easydiffusion import app
config = app.getConfig()
vram_usage_level = config.get("vram_usage_level", "balanced")
if vram_usage_level != context.vram_usage_level:
context.vram_usage_level = vram_usage_level
return True
return False

View File

@ -8,8 +8,17 @@ import os
import traceback import traceback
from typing import List, Union from typing import List, Union
from easydiffusion import app, model_manager, task_manager from easydiffusion import app, model_manager, task_manager, package_manager
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData from easydiffusion.tasks import RenderTask, FilterTask
from easydiffusion.types import (
GenerateImageRequest,
FilterImageRequest,
MergeRequest,
TaskData,
ModelsData,
OutputFormatData,
convert_legacy_render_req_to_new,
)
from easydiffusion.utils import log from easydiffusion.utils import log
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
@ -97,6 +106,10 @@ def init():
def render(req: dict): def render(req: dict):
return render_internal(req) return render_internal(req)
@server_api.post("/filter")
def render(req: dict):
return filter_internal(req)
@server_api.post("/model/merge") @server_api.post("/model/merge")
def model_merge(req: dict): def model_merge(req: dict):
print(req) print(req)
@ -122,6 +135,10 @@ def init():
def stop_cloudflare_tunnel(req: dict): def stop_cloudflare_tunnel(req: dict):
return stop_cloudflare_tunnel_internal(req) return stop_cloudflare_tunnel_internal(req)
@server_api.post("/package/{package_name:str}")
def modify_package(package_name: str, req: dict):
return modify_package_internal(package_name, req)
@server_api.get("/") @server_api.get("/")
def read_root(): def read_root():
return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS) return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
@ -213,24 +230,36 @@ def ping_internal(session_id: str = None):
if task_manager.current_state_error: if task_manager.current_state_error:
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail="Render thread is dead.") raise HTTPException(status_code=500, detail="Render thread is dead.")
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive # Alive
response = {"status": str(task_manager.current_state)} response = {"status": str(task_manager.current_state)}
if session_id: if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True) session = task_manager.get_cached_session(session_id, update_ttl=True)
response["tasks"] = {id(t): t.status for t in session.tasks} response["tasks"] = {id(t): t.status for t in session.tasks}
response["devices"] = task_manager.get_devices() response["devices"] = task_manager.get_devices()
response["packages_installed"] = package_manager.get_installed_packages()
response["packages_installing"] = package_manager.installing
if cloudflare.address != None: if cloudflare.address != None:
response["cloudflare"] = cloudflare.address response["cloudflare"] = cloudflare.address
return JSONResponse(response, headers=NOCACHE_HEADERS) return JSONResponse(response, headers=NOCACHE_HEADERS)
def render_internal(req: dict): def render_internal(req: dict):
try: try:
req = convert_legacy_render_req_to_new(req)
# separate out the request data into rendering and task-specific data # separate out the request data into rendering and task-specific data
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req) render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
task_data: TaskData = TaskData.parse_obj(req) task_data: TaskData = TaskData.parse_obj(req)
models_data: ModelsData = ModelsData.parse_obj(req)
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
# Overwrite user specified save path # Overwrite user specified save path
config = app.getConfig() config = app.getConfig()
@ -240,28 +269,53 @@ def render_internal(req: dict):
render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
app.save_to_config( app.save_to_config(
task_data.use_stable_diffusion_model, models_data.model_paths.get("stable-diffusion"),
task_data.use_vae_model, models_data.model_paths.get("vae"),
task_data.use_hypernetwork_model, models_data.model_paths.get("hypernetwork"),
task_data.vram_usage_level, task_data.vram_usage_level,
) )
# enqueue the task # enqueue the task
new_task = task_manager.render(render_req, task_data) task = RenderTask(render_req, task_data, models_data, output_format)
return enqueue_task(task)
except HTTPException as e:
raise e
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def filter_internal(req: dict):
try:
session_id = req.get("session_id", "session")
filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req)
models_data: ModelsData = ModelsData.parse_obj(req)
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
# enqueue the task
task = FilterTask(filter_req, session_id, models_data, output_format)
return enqueue_task(task)
except HTTPException as e:
raise e
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def enqueue_task(task):
try:
task_manager.enqueue_task(task)
response = { response = {
"status": str(task_manager.current_state), "status": str(task_manager.current_state),
"queue": len(task_manager.tasks_queue), "queue": len(task_manager.tasks_queue),
"stream": f"/image/stream/{id(new_task)}", "stream": f"/image/stream/{task.id}",
"task": id(new_task), "task": task.id,
} }
return JSONResponse(response, headers=NOCACHE_HEADERS) return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def model_merge_internal(req: dict): def model_merge_internal(req: dict):
@ -381,3 +435,19 @@ def stop_cloudflare_tunnel_internal(req: dict):
log.error(str(e)) log.error(str(e))
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
def modify_package_internal(package_name: str, req: dict):
try:
cmd = req["command"]
if cmd not in ("install", "uninstall"):
raise RuntimeError(f"Unknown command: {cmd}")
cmd = getattr(package_manager, cmd)
cmd(package_name)
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
except Exception as e:
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))

View File

@ -17,7 +17,7 @@ from typing import Any, Hashable
import torch import torch
from easydiffusion import device_manager from easydiffusion import device_manager
from easydiffusion.types import GenerateImageRequest, TaskData from easydiffusion.tasks import Task
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit.utils import gc from sdkit.utils import gc
@ -27,6 +27,7 @@ LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init. DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
MAX_OVERLOAD_ALLOWED_RATIO = 2 # i.e. 2x pending tasks compared to the number of render threads
class SymbolClass(type): # Print nicely formatted Symbol names. class SymbolClass(type): # Print nicely formatted Symbol names.
@ -58,46 +59,6 @@ class ServerStates:
pass pass
class RenderTask: # Task with output queue and completion lock.
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
task_data.request_id = id(self)
self.render_request: GenerateImageRequest = req # Initial Request
self.task_data: TaskData = task_data
self.response: Any = None # Copy of the last reponse
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
async def read_buffer_generator(self):
try:
while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done()
yield res
except queue.Empty as e:
yield
@property
def status(self):
if self.lock.locked():
return "running"
if isinstance(self.error, StopAsyncIteration):
return "stopped"
if self.error:
return "error"
if not self.buffer_queue.empty():
return "buffer"
if self.response:
return "completed"
return "pending"
@property
def is_pending(self):
return bool(not self.response and not self.error)
# Temporary cache to allow to query tasks results for a short time after they are completed. # Temporary cache to allow to query tasks results for a short time after they are completed.
class DataCache: class DataCache:
def __init__(self): def __init__(self):
@ -123,8 +84,8 @@ class DataCache:
# Remove Items # Remove Items
for key in to_delete: for key in to_delete:
(_, val) = self._base[key] (_, val) = self._base[key]
if isinstance(val, RenderTask): if isinstance(val, Task):
log.debug(f"RenderTask {key} expired. Data removed.") log.debug(f"Task {key} expired. Data removed.")
elif isinstance(val, SessionState): elif isinstance(val, SessionState):
log.debug(f"Session {key} expired. Data removed.") log.debug(f"Session {key} expired. Data removed.")
else: else:
@ -220,8 +181,8 @@ class SessionState:
tasks.append(task) tasks.append(task)
return tasks return tasks
def put(self, task, ttl=TASK_TTL): def put(self, task: Task, ttl=TASK_TTL):
task_id = id(task) task_id = task.id
self._tasks_ids.append(task_id) self._tasks_ids.append(task_id)
if not task_cache.put(task_id, task, ttl): if not task_cache.put(task_id, task, ttl):
return False return False
@ -230,11 +191,16 @@ class SessionState:
return True return True
def keep_task_alive(task: Task):
task_cache.keep(task.id, TASK_TTL)
session_cache.keep(task.session_id, TASK_TTL)
def thread_get_next_task(): def thread_get_next_task():
from easydiffusion import renderer from easydiffusion import runtime
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.") log.warn(f"Render thread on device: {runtime.context.device} failed to acquire manager lock.")
return None return None
if len(tasks_queue) <= 0: if len(tasks_queue) <= 0:
manager_lock.release() manager_lock.release()
@ -242,7 +208,7 @@ def thread_get_next_task():
task = None task = None
try: # Select a render task. try: # Select a render task.
for queued_task in tasks_queue: for queued_task in tasks_queue:
if queued_task.render_device and renderer.context.device != queued_task.render_device: if queued_task.render_device and runtime.context.device != queued_task.render_device:
# Is asking for a specific render device. # Is asking for a specific render device.
if is_alive(queued_task.render_device) > 0: if is_alive(queued_task.render_device) > 0:
continue # requested device alive, skip current one. continue # requested device alive, skip current one.
@ -251,7 +217,7 @@ def thread_get_next_task():
queued_task.error = Exception(queued_task.render_device + " is not currently active.") queued_task.error = Exception(queued_task.render_device + " is not currently active.")
task = queued_task task = queued_task
break break
if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1: if not queued_task.render_device and runtime.context.device == "cpu" and is_alive() > 1:
# not asking for any specific devices, cpu want to grab task but other render devices are alive. # not asking for any specific devices, cpu want to grab task but other render devices are alive.
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
task = queued_task task = queued_task
@ -266,19 +232,19 @@ def thread_get_next_task():
def thread_render(device): def thread_render(device):
global current_state, current_state_error global current_state, current_state_error
from easydiffusion import model_manager, renderer from easydiffusion import model_manager, runtime
try: try:
renderer.init(device) runtime.init(device)
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
"device": renderer.context.device, "device": runtime.context.device,
"device_name": renderer.context.device_name, "device_name": runtime.context.device_name,
"alive": True, "alive": True,
} }
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
model_manager.load_default_models(renderer.context) model_manager.load_default_models(runtime.context)
current_state = ServerStates.Online current_state = ServerStates.Online
except Exception as e: except Exception as e:
@ -290,8 +256,8 @@ def thread_render(device):
session_cache.clean() session_cache.clean()
task_cache.clean() task_cache.clean()
if not weak_thread_data[threading.current_thread()]["alive"]: if not weak_thread_data[threading.current_thread()]["alive"]:
log.info(f"Shutting down thread for device {renderer.context.device}") log.info(f"Shutting down thread for device {runtime.context.device}")
model_manager.unload_all(renderer.context) model_manager.unload_all(runtime.context)
return return
if isinstance(current_state_error, SystemExit): if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable current_state = ServerStates.Unavailable
@ -311,62 +277,31 @@ def thread_render(device):
task.response = {"status": "failed", "detail": str(task.error)} task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
continue continue
log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}") log.info(f"Session {task.session_id} starting task {task.id} on {runtime.context.device_name}")
if not task.lock.acquire(blocking=False): if not task.lock.acquire(blocking=False):
raise Exception("Got locked task from queue.") raise Exception("Got locked task from queue.")
try: try:
task.run()
def step_callback():
global current_state_error
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
if (
isinstance(current_state_error, SystemExit)
or isinstance(current_state_error, StopAsyncIteration)
or isinstance(task.error, StopAsyncIteration)
):
renderer.context.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
current_state = ServerStates.LoadingModel
model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
model_manager.fail_if_models_did_not_load(renderer.context)
current_state = ServerStates.Rendering
task.response = renderer.make_images(
task.render_request,
task.task_data,
task.buffer_queue,
task.temp_images,
step_callback,
)
# Before looping back to the generator, mark cache as still alive. # Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL) keep_task_alive(task)
session_cache.keep(task.task_data.session_id, TASK_TTL)
except Exception as e: except Exception as e:
task.error = str(e) task.error = str(e)
task.response = {"status": "failed", "detail": str(task.error)} task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
log.error(traceback.format_exc()) log.error(traceback.format_exc())
finally: finally:
gc(renderer.context) gc(runtime.context)
task.lock.release() task.lock.release()
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL) keep_task_alive(task)
if isinstance(task.error, StopAsyncIteration): if isinstance(task.error, StopAsyncIteration):
log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!") log.info(f"Session {task.session_id} task {task.id} cancelled!")
elif task.error is not None: elif task.error is not None:
log.info(f"Session {task.task_data.session_id} task {id(task)} failed!") log.info(f"Session {task.session_id} task {task.id} failed!")
else: else:
log.info( log.info(f"Session {task.session_id} task {task.id} completed by {runtime.context.device_name}.")
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
)
current_state = ServerStates.Online current_state = ServerStates.Online
@ -438,6 +373,12 @@ def get_devices():
finally: finally:
manager_lock.release() manager_lock.release()
# temp until TRT releases
import os
from easydiffusion import app
devices["enable_trt"] = os.path.exists(os.path.join(app.ROOT_DIR, "tensorrt"))
return devices return devices
@ -548,28 +489,27 @@ def shutdown_event(): # Signal render thread to close on shutdown
current_state_error = SystemExit("Application shutting down.") current_state_error = SystemExit("Application shutting down.")
def render(render_req: GenerateImageRequest, task_data: TaskData): def enqueue_task(task: Task):
current_thread_count = is_alive() current_thread_count = is_alive()
if current_thread_count <= 0: # Render thread is dead if current_thread_count <= 0: # Render thread is dead
raise ChildProcessError("Rendering thread has died.") raise ChildProcessError("Rendering thread has died.")
# Alive, check if task in cache # Alive, check if task in cache
session = get_cached_session(task_data.session_id, update_ttl=True) session = get_cached_session(task.session_id, update_ttl=True)
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
if current_thread_count < len(pending_tasks): if len(pending_tasks) > current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO:
raise ConnectionRefusedError( raise ConnectionRefusedError(
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}." f"Session {task.session_id} already has {len(pending_tasks)} pending tasks, with {current_thread_count} workers."
) )
new_task = RenderTask(render_req, task_data) if session.put(task, TASK_TTL):
if session.put(new_task, TASK_TTL):
# Use twice the normal timeout for adding user requests. # Use twice the normal timeout for adding user requests.
# Tries to force session.put to fail before tasks_queue.put would. # Tries to force session.put to fail before tasks_queue.put would.
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2): if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
try: try:
tasks_queue.append(new_task) tasks_queue.append(task)
idle_event.set() idle_event.set()
return new_task return task
finally: finally:
manager_lock.release() manager_lock.release()
raise RuntimeError("Failed to add task to cache.") raise RuntimeError("Failed to add task to cache.")

View File

@ -0,0 +1,3 @@
from .task import Task
from .render_images import RenderTask
from .filter_images import FilterTask

View File

@ -0,0 +1,110 @@
import json
import pprint
from sdkit.filter import apply_filters
from sdkit.models import load_model
from sdkit.utils import img_to_base64_str, log
from easydiffusion import model_manager, runtime
from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData
from .task import Task
class FilterTask(Task):
"For applying filters to input images"
def __init__(
self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData
):
super().__init__(session_id)
self.request = req
self.models_data = models_data
self.output_format = output_format
# convert to multi-filter format, if necessary
if isinstance(req.filter, str):
req.filter_params = {req.filter: req.filter_params}
req.filter = [req.filter]
if not isinstance(req.image, list):
req.image = [req.image]
def run(self):
"Runs the image filtering task on the assigned thread"
context = runtime.context
model_manager.resolve_model_paths(self.models_data)
model_manager.reload_models_if_necessary(context, self.models_data)
model_manager.fail_if_models_did_not_load(context)
print_task_info(self.request, self.models_data, self.output_format)
images = filter_images(context, self.request.image, self.request.filter, self.request.filter_params)
output_format = self.output_format
images = [
img_to_base64_str(
img, output_format.output_format, output_format.output_quality, output_format.output_lossless
)
for img in images
]
res = FilterImageResponse(self.request, self.models_data, images=images)
res = res.json()
self.buffer_queue.put(json.dumps(res))
log.info("Filter task completed")
self.response = res
def filter_images(context, images, filters, filter_params={}):
filters = filters if isinstance(filters, list) else [filters]
for filter_name in filters:
params = filter_params.get(filter_name, {})
previous_state = before_filter(context, filter_name, params)
try:
images = apply_filters(context, filter_name, images, **params)
finally:
after_filter(context, filter_name, params, previous_state)
return images
def before_filter(context, filter_name, filter_params):
if filter_name == "codeformer":
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
upscale_faces = filter_params.get("upscale_faces", False)
if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths.get("realesrgan")
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
return prev_realesrgan_path
def after_filter(context, filter_name, filter_params, previous_state):
if filter_name == "codeformer":
prev_realesrgan_path = previous_state
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData):
req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
log.info(f"request: {req_str}")
log.info(f"models data: {models_data}")
log.info(f"output format: {output_format}")

View File

@ -3,70 +3,115 @@ import pprint
import queue import queue
import time import time
from easydiffusion import device_manager from easydiffusion import model_manager, runtime
from easydiffusion.types import GenerateImageRequest from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData
from easydiffusion.types import Image as ResponseImage from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import Response, TaskData, UserInitiatedStop from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
from easydiffusion.utils import get_printable_request, log, save_images_to_disk from easydiffusion.utils import get_printable_request, log, save_images_to_disk
from sdkit import Context
from sdkit.filter import apply_filters
from sdkit.generate import generate_images from sdkit.generate import generate_images
from sdkit.models import load_model
from sdkit.utils import ( from sdkit.utils import (
diffusers_latent_samples_to_images, diffusers_latent_samples_to_images,
gc, gc,
img_to_base64_str, img_to_base64_str,
img_to_buffer, img_to_buffer,
latent_samples_to_images, latent_samples_to_images,
get_device_usage,
) )
context = Context() # thread-local from .task import Task
""" from .filter_images import filter_images
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
"""
def init(device): class RenderTask(Task):
""" "For image generation"
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
"""
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
context.model_load_errors = {}
context.enable_codeformer = True
from easydiffusion import app def __init__(
self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
):
super().__init__(task_data.session_id)
app_config = app.getConfig() task_data.request_id = self.id
context.test_diffusers = ( self.render_request: GenerateImageRequest = req # Initial Request
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main" self.task_data: TaskData = task_data
) self.models_data = models_data
self.output_format = output_format
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
log.info("Device usage during initialization:") def run(self):
get_device_usage(device, log_info=True, process_usage_only=False) "Runs the image generation task on the assigned thread"
device_manager.device_init(context, device) from easydiffusion import task_manager
context = runtime.context
def step_callback():
task_manager.keep_task_alive(self)
task_manager.current_state = task_manager.ServerStates.Rendering
if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance(
self.error, StopAsyncIteration
):
context.stop_processing = True
if isinstance(task_manager.current_state_error, StopAsyncIteration):
self.error = task_manager.current_state_error
task_manager.current_state_error = None
log.info(f"Session {self.session_id} sent cancel signal for task {self.id}")
task_manager.current_state = task_manager.ServerStates.LoadingModel
model_manager.resolve_model_paths(self.models_data)
models_to_force_reload = []
if (
runtime.set_vram_optimizations(context)
or self.has_param_changed(context, "clip_skip")
or self.has_param_changed(context, "convert_to_tensorrt")
):
models_to_force_reload.append("stable-diffusion")
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
model_manager.fail_if_models_did_not_load(context)
task_manager.current_state = task_manager.ServerStates.Rendering
self.response = make_images(
context,
self.render_request,
self.task_data,
self.models_data,
self.output_format,
self.buffer_queue,
self.temp_images,
step_callback,
)
def has_param_changed(self, context, param_name):
if not context.test_diffusers:
return False
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
return True
model = context.models["stable-diffusion"]
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
return model["params"].get(param_name) != new_val
def make_images( def make_images(
context,
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: TaskData, task_data: TaskData,
models_data: ModelsData,
output_format: OutputFormatData,
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
): ):
context.stop_processing = False context.stop_processing = False
print_task_info(req, task_data) print_task_info(req, task_data, models_data, output_format)
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) images, seeds = make_images_internal(
context, req, task_data, models_data, output_format, data_queue, task_temp_images, step_callback
)
res = Response( res = GenerateImageResponse(
req, req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format)
task_data,
images=construct_response(images, seeds, task_data, base_seed=req.seed),
) )
res = res.json() res = res.json()
data_queue.put(json.dumps(res)) data_queue.put(json.dumps(res))
@ -75,21 +120,32 @@ def make_images(
return res return res
def print_task_info(req: GenerateImageRequest, task_data: TaskData): def print_task_info(
req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[") req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
):
req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[") task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
log.info(f"request: {req_str}") log.info(f"request: {req_str}")
log.info(f"task data: {task_str}") log.info(f"task data: {task_str}")
# log.info(f"models data: {models_data}")
log.info(f"output format: {output_format}")
def make_images_internal( def make_images_internal(
context,
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: TaskData, task_data: TaskData,
models_data: ModelsData,
output_format: OutputFormatData,
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
): ):
images, user_stopped = generate_images_internal( images, user_stopped = generate_images_internal(
context,
req, req,
task_data, task_data,
data_queue, data_queue,
@ -98,11 +154,14 @@ def make_images_internal(
task_data.stream_image_progress, task_data.stream_image_progress,
task_data.stream_image_progress_interval, task_data.stream_image_progress_interval,
) )
gc(context) gc(context)
filtered_images = filter_images(req, task_data, images, user_stopped)
filters, filter_params = task_data.filters, task_data.filter_params
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
if task_data.save_to_disk_path is not None: if task_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data) save_images_to_disk(images, filtered_images, req, task_data, output_format)
seeds = [*range(req.seed, req.seed + len(images))] seeds = [*range(req.seed, req.seed + len(images))]
if task_data.show_only_filtered_image or filtered_images is images: if task_data.show_only_filtered_image or filtered_images is images:
@ -112,6 +171,7 @@ def make_images_internal(
def generate_images_internal( def generate_images_internal(
context,
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: TaskData, task_data: TaskData,
data_queue: queue.Queue, data_queue: queue.Queue,
@ -123,6 +183,7 @@ def generate_images_internal(
context.temp_images.clear() context.temp_images.clear()
callback = make_step_callback( callback = make_step_callback(
context,
req, req,
task_data, task_data,
data_queue, data_queue,
@ -155,65 +216,14 @@ def generate_images_internal(
return images, user_stopped return images, user_stopped
def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped): def construct_response(images: list, seeds: list, output_format: OutputFormatData):
if user_stopped:
return images
if task_data.block_nsfw:
images = apply_filters(context, "nsfw_checker", images)
if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths["realesrgan"]
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
try:
images = apply_filters(
context,
"codeformer",
images,
upscale_faces=task_data.codeformer_upscale_faces,
codeformer_fidelity=task_data.codeformer_fidelity,
)
finally:
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
images = apply_filters(context, "gfpgan", images)
if task_data.use_upscale:
if "realesrgan" in task_data.use_upscale.lower():
images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount)
elif task_data.use_upscale == "latent_upscaler":
images = apply_filters(
context,
"latent_upscaler",
images,
scale=task_data.upscale_amount,
latent_upscaler_options={
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"seed": req.seed,
"num_inference_steps": task_data.latent_upscaler_steps,
"guidance_scale": 0,
},
)
return images
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
return [ return [
ResponseImage( ResponseImage(
data=img_to_base64_str( data=img_to_base64_str(
img, img,
task_data.output_format, output_format.output_format,
task_data.output_quality, output_format.output_quality,
task_data.output_lossless, output_format.output_lossless,
), ),
seed=seed, seed=seed,
) )
@ -222,6 +232,7 @@ def construct_response(images: list, seeds: list, task_data: TaskData, base_seed
def make_step_callback( def make_step_callback(
context,
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: TaskData, task_data: TaskData,
data_queue: queue.Queue, data_queue: queue.Queue,
@ -242,7 +253,7 @@ def make_step_callback(
images = latent_samples_to_images(context, x_samples) images = latent_samples_to_images(context, x_samples)
if task_data.block_nsfw: if task_data.block_nsfw:
images = apply_filters(context, "nsfw_checker", images) images = filter_images(context, images, "nsfw_checker")
for i, img in enumerate(images): for i, img in enumerate(images):
buf = img_to_buffer(img, output_format="JPEG") buf = img_to_buffer(img, output_format="JPEG")

View File

@ -0,0 +1,47 @@
from threading import Lock
from queue import Queue, Empty as EmptyQueueException
from typing import Any
class Task:
"Task with output queue and completion lock"
def __init__(self, session_id):
self.id = id(self)
self.session_id = session_id
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.error: Exception = None
self.lock: Lock = Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: Queue = Queue() # Queue of JSON string segments
self.response: Any = None # Copy of the last reponse
async def read_buffer_generator(self):
try:
while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done()
yield res
except EmptyQueueException as e:
yield
@property
def status(self):
if self.lock.locked():
return "running"
if isinstance(self.error, StopAsyncIteration):
return "stopped"
if self.error:
return "error"
if not self.buffer_queue.empty():
return "buffer"
if self.response:
return "completed"
return "pending"
@property
def is_pending(self):
return bool(not self.response and not self.error)
def run(self):
"Override this to implement the task's behavior"
pass

View File

@ -1,4 +1,4 @@
from typing import Any, List, Union from typing import Any, List, Dict, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -17,8 +17,11 @@ class GenerateImageRequest(BaseModel):
init_image: Any = None init_image: Any = None
init_image_mask: Any = None init_image_mask: Any = None
control_image: Any = None
control_alpha: Union[float, List[float]] = None
prompt_strength: float = 0.8 prompt_strength: float = 0.8
preserve_init_image_color_profile = False preserve_init_image_color_profile = False
strict_mask_border = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
hypernetwork_strength: float = 0 hypernetwork_strength: float = 0
@ -26,6 +29,35 @@ class GenerateImageRequest(BaseModel):
tiling: str = "none" # "none", "x", "y", "xy" tiling: str = "none" # "none", "x", "y", "xy"
class FilterImageRequest(BaseModel):
image: Any = None
filter: Union[str, List[str]] = None
filter_params: dict = {}
class ModelsData(BaseModel):
"""
Contains the information related to the models involved in a request.
- To load a model: set the relative path(s) to the model in `model_paths`. No effect if already loaded.
- To unload a model: set the model to `None` in `model_paths`. No effect if already unloaded.
Models that aren't present in `model_paths` will not be changed.
"""
model_paths: Dict[str, Union[str, None, List[str]]] = None
"model_type to string path, or list of string paths"
model_params: Dict[str, Dict[str, Any]] = {}
"model_type to dict of parameters"
class OutputFormatData(BaseModel):
output_format: str = "jpeg" # or "png" or "webp"
output_quality: int = 75
output_lossless: bool = False
class TaskData(BaseModel): class TaskData(BaseModel):
request_id: str = None request_id: str = None
session_id: str = "session" session_id: str = "session"
@ -40,12 +72,12 @@ class TaskData(BaseModel):
use_vae_model: Union[str, List[str]] = None use_vae_model: Union[str, List[str]] = None
use_hypernetwork_model: Union[str, List[str]] = None use_hypernetwork_model: Union[str, List[str]] = None
use_lora_model: Union[str, List[str]] = None use_lora_model: Union[str, List[str]] = None
use_controlnet_model: Union[str, List[str]] = None
filters: List[str] = []
filter_params: Dict[str, Dict[str, Any]] = {}
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
block_nsfw: bool = False block_nsfw: bool = False
output_format: str = "jpeg" # or "png" or "webp"
output_quality: int = 75
output_lossless: bool = False
metadata_output_format: str = "txt" # or "json" metadata_output_format: str = "txt" # or "json"
stream_image_progress: bool = False stream_image_progress: bool = False
stream_image_progress_interval: int = 5 stream_image_progress_interval: int = 5
@ -80,24 +112,38 @@ class Image:
} }
class Response: class GenerateImageResponse:
render_request: GenerateImageRequest render_request: GenerateImageRequest
task_data: TaskData task_data: TaskData
models_data: ModelsData
images: list images: list
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list): def __init__(
self,
render_request: GenerateImageRequest,
task_data: TaskData,
models_data: ModelsData,
output_format: OutputFormatData,
images: list,
):
self.render_request = render_request self.render_request = render_request
self.task_data = task_data self.task_data = task_data
self.models_data = models_data
self.output_format = output_format
self.images = images self.images = images
def json(self): def json(self):
del self.render_request.init_image del self.render_request.init_image
del self.render_request.init_image_mask del self.render_request.init_image_mask
task_data = self.task_data.dict()
task_data.update(self.output_format.dict())
res = { res = {
"status": "succeeded", "status": "succeeded",
"render_request": self.render_request.dict(), "render_request": self.render_request.dict(),
"task_data": self.task_data.dict(), "task_data": task_data,
# "models_data": self.models_data.dict(), # haven't migrated the UI to the new format (yet)
"output": [], "output": [],
} }
@ -107,5 +153,105 @@ class Response:
return res return res
class FilterImageResponse:
request: FilterImageRequest
models_data: ModelsData
images: list
def __init__(self, request: FilterImageRequest, models_data: ModelsData, images: list):
self.request = request
self.models_data = models_data
self.images = images
def json(self):
del self.request.image
res = {
"status": "succeeded",
"request": self.request.dict(),
"models_data": self.models_data.dict(),
"output": [],
}
for image in self.images:
res["output"].append(image)
return res
class UserInitiatedStop(Exception): class UserInitiatedStop(Exception):
pass pass
def convert_legacy_render_req_to_new(old_req: dict):
new_req = dict(old_req)
# new keys
model_paths = new_req["model_paths"] = {}
model_params = new_req["model_params"] = {}
filters = new_req["filters"] = []
filter_params = new_req["filter_params"] = {}
# move the model info
model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
model_paths["vae"] = old_req.get("use_vae_model")
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
model_paths["lora"] = old_req.get("use_lora_model")
model_paths["controlnet"] = old_req.get("use_controlnet_model")
model_paths["gfpgan"] = old_req.get("use_face_correction", "")
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None
model_paths["codeformer"] = old_req.get("use_face_correction", "")
model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None
model_paths["realesrgan"] = old_req.get("use_upscale", "")
model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None
model_paths["latent_upscaler"] = old_req.get("use_upscale", "")
model_paths["latent_upscaler"] = (
model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
)
if old_req.get("block_nsfw"):
model_paths["nsfw_checker"] = "nsfw_checker"
# move the model params
if model_paths["stable-diffusion"]:
model_params["stable-diffusion"] = {
"clip_skip": bool(old_req.get("clip_skip", False)),
"convert_to_tensorrt": bool(old_req.get("convert_to_tensorrt", False)),
}
# move the filter params
if model_paths["realesrgan"]:
filter_params["realesrgan"] = {"scale": int(old_req.get("upscale_amount", 4))}
if model_paths["latent_upscaler"]:
filter_params["latent_upscaler"] = {
"prompt": old_req["prompt"],
"negative_prompt": old_req.get("negative_prompt"),
"seed": int(old_req.get("seed", 42)),
"num_inference_steps": int(old_req.get("latent_upscaler_steps", 10)),
"guidance_scale": 0,
}
if model_paths["codeformer"]:
filter_params["codeformer"] = {
"upscale_faces": bool(old_req.get("codeformer_upscale_faces", True)),
"codeformer_fidelity": float(old_req.get("codeformer_fidelity", 0.5)),
}
# set the filters
if old_req.get("block_nsfw"):
filters.append("nsfw_checker")
if model_paths["codeformer"]:
filters.append("codeformer")
elif model_paths["gfpgan"]:
filters.append("gfpgan")
if model_paths["realesrgan"]:
filters.append("realesrgan")
elif model_paths["latent_upscaler"]:
filters.append("latent_upscaler")
return new_req

View File

@ -7,7 +7,7 @@ from datetime import datetime
from functools import reduce from functools import reduce
from easydiffusion import app from easydiffusion import app
from easydiffusion.types import GenerateImageRequest, TaskData from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
from numpy import base_repr from numpy import base_repr
from sdkit.utils import save_dicts, save_images from sdkit.utils import save_dicts, save_images
@ -114,12 +114,14 @@ def format_file_name(
return filename_regex.sub("_", format) return filename_regex.sub("_", format)
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): def save_images_to_disk(
images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData
):
now = time.time() now = time.time()
app_config = app.getConfig() app_config = app.getConfig()
folder_format = app_config.get("folder_format", "$id") folder_format = app_config.get("folder_format", "$id")
save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
metadata_entries = get_metadata_entries_for_request(req, task_data) metadata_entries = get_metadata_entries_for_request(req, task_data, output_format)
file_number = calculate_img_number(save_dir_path, task_data) file_number = calculate_img_number(save_dir_path, task_data)
make_filename = make_filename_callback( make_filename = make_filename_callback(
app_config.get("filename_format", "$p_$tsb64"), app_config.get("filename_format", "$p_$tsb64"),
@ -134,9 +136,9 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
filtered_images, filtered_images,
save_dir_path, save_dir_path,
file_name=make_filename, file_name=make_filename,
output_format=task_data.output_format, output_format=output_format.output_format,
output_quality=task_data.output_quality, output_quality=output_format.output_quality,
output_lossless=task_data.output_lossless, output_lossless=output_format.output_lossless,
) )
if task_data.metadata_output_format: if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(","): for metadata_output_format in task_data.metadata_output_format.split(","):
@ -146,7 +148,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
save_dir_path, save_dir_path,
file_name=make_filename, file_name=make_filename,
output_format=metadata_output_format, output_format=metadata_output_format,
file_format=task_data.output_format, file_format=output_format.output_format,
) )
else: else:
make_filter_filename = make_filename_callback( make_filter_filename = make_filename_callback(
@ -162,17 +164,17 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
images, images,
save_dir_path, save_dir_path,
file_name=make_filename, file_name=make_filename,
output_format=task_data.output_format, output_format=output_format.output_format,
output_quality=task_data.output_quality, output_quality=output_format.output_quality,
output_lossless=task_data.output_lossless, output_lossless=output_format.output_lossless,
) )
save_images( save_images(
filtered_images, filtered_images,
save_dir_path, save_dir_path,
file_name=make_filter_filename, file_name=make_filter_filename,
output_format=task_data.output_format, output_format=output_format.output_format,
output_quality=task_data.output_quality, output_quality=output_format.output_quality,
output_lossless=task_data.output_lossless, output_lossless=output_format.output_lossless,
) )
if task_data.metadata_output_format: if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(","): for metadata_output_format in task_data.metadata_output_format.split(","):
@ -181,20 +183,21 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
metadata_entries, metadata_entries,
save_dir_path, save_dir_path,
file_name=make_filter_filename, file_name=make_filter_filename,
output_format=task_data.metadata_output_format, output_format=metadata_output_format,
file_format=task_data.output_format, file_format=output_format.output_format,
) )
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData): def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
metadata = get_printable_request(req, task_data) metadata = get_printable_request(req, task_data, output_format)
# if text, format it in the text format expected by the UI # if text, format it in the text format expected by the UI
is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",") is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",")
if is_txt_format: if is_txt_format:
def format_value(value): def format_value(value):
if isinstance(value, list): if isinstance(value, list):
return ", ".join([ str(it) for it in value ]) return ", ".join([str(it) for it in value])
return value return value
metadata = { metadata = {
@ -208,9 +211,10 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD
return entries return entries
def get_printable_request(req: GenerateImageRequest, task_data: TaskData): def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
req_metadata = req.dict() req_metadata = req.dict()
task_data_metadata = task_data.dict() task_data_metadata = task_data.dict()
task_data_metadata.update(output_format.dict())
app_config = app.getConfig() app_config = app.getConfig()
using_diffusers = app_config.get("test_diffusers", False) using_diffusers = app_config.get("test_diffusers", False)
@ -222,8 +226,9 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
metadata[key] = req_metadata[key] metadata[key] = req_metadata[key]
elif key in task_data_metadata: elif key in task_data_metadata:
metadata[key] = task_data_metadata[key] metadata[key] = task_data_metadata[key]
elif key is "use_embedding_models" and using_diffusers: elif key == "use_embedding_models" and using_diffusers:
embeddings_extensions = {".pt", ".bin", ".safetensors"} embeddings_extensions = {".pt", ".bin", ".safetensors"}
def scan_directory(directory_path: str): def scan_directory(directory_path: str):
used_embeddings = [] used_embeddings = []
for entry in os.scandir(directory_path): for entry in os.scandir(directory_path):
@ -232,15 +237,18 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
if entry_extension not in embeddings_extensions: if entry_extension not in embeddings_extensions:
continue continue
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])") embedding_name_regex = regex.compile(
r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
)
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt): if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
used_embeddings.append(entry.path) used_embeddings.append(entry.path)
elif entry.is_dir(): elif entry.is_dir():
used_embeddings.extend(scan_directory(entry.path)) used_embeddings.extend(scan_directory(entry.path))
return used_embeddings return used_embeddings
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings")) used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
# Clean up the metadata # Clean up the metadata
if req.init_image is None and "prompt_strength" in metadata: if req.init_image is None and "prompt_strength" in metadata:
del metadata["prompt_strength"] del metadata["prompt_strength"]
@ -254,7 +262,9 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
del metadata["latent_upscaler_steps"] del metadata["latent_upscaler_steps"]
if not using_diffusers: if not using_diffusers:
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata): for key in (
x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata
):
del metadata[key] del metadata[key]
return metadata return metadata

View File

@ -17,6 +17,7 @@
<link rel="stylesheet" href="/media/css/searchable-models.css"> <link rel="stylesheet" href="/media/css/searchable-models.css">
<link rel="stylesheet" href="/media/css/image-modal.css"> <link rel="stylesheet" href="/media/css/image-modal.css">
<link rel="stylesheet" href="/media/css/plugins.css"> <link rel="stylesheet" href="/media/css/plugins.css">
<link rel="stylesheet" href="/media/css/animations.css">
<link rel="manifest" href="/media/manifest.webmanifest"> <link rel="manifest" href="/media/manifest.webmanifest">
<script src="/media/js/jquery-3.6.1.min.js"></script> <script src="/media/js/jquery-3.6.1.min.js"></script>
<script src="/media/js/jquery-confirm.min.js"></script> <script src="/media/js/jquery-confirm.min.js"></script>
@ -31,7 +32,7 @@
<h1> <h1>
<img id="logo_img" src="/media/images/icon-512x512.png" > <img id="logo_img" src="/media/images/icon-512x512.png" >
Easy Diffusion Easy Diffusion
<small><span id="version">v2.5.45</span> <span id="updateBranchLabel"></span></small> <small><span id="version">v2.5.47</span> <span id="updateBranchLabel"></span></small>
</h1> </h1>
</div> </div>
<div id="server-status"> <div id="server-status">
@ -108,6 +109,7 @@
</div> </div>
<div id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Preserve color profile <small>(helps during inpainting)</small></label></div> <div id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Preserve color profile <small>(helps during inpainting)</small></label></div>
<div id="strict_mask_border_setting" class="pl-5"><input id="strict_mask_border" name="strict_mask_border" type="checkbox"> <label for="strict_mask_border">Strict Mask Border <small>(won't modify outside the mask, but the mask border might be visible)</small></label></div>
</div> </div>
@ -145,6 +147,14 @@
<button id="reload-models" class="secondaryButton reloadModels"><i class='fa-solid fa-rotate'></i></button> <button id="reload-models" class="secondaryButton reloadModels"><i class='fa-solid fa-rotate'></i></button>
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
</td></tr> </td></tr>
<tr class="pl-5 displayNone" id="enable_trt_config">
<td><label for="convert_to_tensorrt">Convert to TensorRT:</label></td>
<td class="diffusers-restart-needed">
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>
<label><small>Takes upto 20 mins the first time</small></label>
</td>
</tr>
<tr class="pl-5 displayNone" id="clip_skip_config"> <tr class="pl-5 displayNone" id="clip_skip_config">
<td><label for="clip_skip">Clip Skip:</label></td> <td><label for="clip_skip">Clip Skip:</label></td>
<td class="diffusers-restart-needed"> <td class="diffusers-restart-needed">
@ -171,7 +181,6 @@
<option value="dpmpp_2m">DPM++ 2m (Karras)</option> <option value="dpmpp_2m">DPM++ 2m (Karras)</option>
<option value="dpmpp_2m_sde" class="diffusers-only">DPM++ 2m SDE (Karras)</option> <option value="dpmpp_2m_sde" class="diffusers-only">DPM++ 2m SDE (Karras)</option>
<option value="dpmpp_sde">DPM++ SDE (Karras)</option> <option value="dpmpp_sde">DPM++ SDE (Karras)</option>
<option value="dpm_fast" class="k_diffusion-only">DPM Fast (Karras)</option>
<option value="dpm_adaptive" class="k_diffusion-only">DPM Adaptive (Karras)</option> <option value="dpm_adaptive" class="k_diffusion-only">DPM Adaptive (Karras)</option>
<option value="ddpm" class="diffusers-only">DDPM</option> <option value="ddpm" class="diffusers-only">DDPM</option>
<option value="deis" class="diffusers-only">DEIS</option> <option value="deis" class="diffusers-only">DEIS</option>
@ -183,15 +192,15 @@
</select> </select>
<a href="https://github.com/easydiffusion/easydiffusion/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
</td></tr> </td></tr>
<tr class="pl-5"><td><label>Image Size: </label></td><td> <tr class="pl-5"><td><label>Image Size: </label></td><td id="image-size-options">
<select id="width" name="width" value="512"> <select id="width" name="width" value="512">
<option value="128">128 (*)</option> <option value="128">128</option>
<option value="192">192</option> <option value="192">192</option>
<option value="256">256 (*)</option> <option value="256">256</option>
<option value="320">320</option> <option value="320">320</option>
<option value="384">384</option> <option value="384">384</option>
<option value="448">448</option> <option value="448">448</option>
<option value="512" selected>512 (*)</option> <option value="512" selected="">512 (*)</option>
<option value="576">576</option> <option value="576">576</option>
<option value="640">640</option> <option value="640">640</option>
<option value="704">704</option> <option value="704">704</option>
@ -206,14 +215,15 @@
<option value="2048">2048</option> <option value="2048">2048</option>
</select> </select>
<label for="width"><small>(width)</small></label> <label for="width"><small>(width)</small></label>
<span id="swap-width-height" class="clickable smallButton" style="margin-left: 2px; margin-right:2px;"><i class="fa-solid fa-right-left"><span class="simple-tooltip top-left"> Swap width and height </span></i></span>
<select id="height" name="height" value="512"> <select id="height" name="height" value="512">
<option value="128">128 (*)</option> <option value="128">128</option>
<option value="192">192</option> <option value="192">192</option>
<option value="256">256 (*)</option> <option value="256">256</option>
<option value="320">320</option> <option value="320">320</option>
<option value="384">384</option> <option value="384">384</option>
<option value="448">448</option> <option value="448">448</option>
<option value="512" selected>512 (*)</option> <option value="512" selected="">512 (*)</option>
<option value="576">576</option> <option value="576">576</option>
<option value="640">640</option> <option value="640">640</option>
<option value="704">704</option> <option value="704">704</option>
@ -228,6 +238,22 @@
<option value="2048">2048</option> <option value="2048">2048</option>
</select> </select>
<label for="height"><small>(height)</small></label> <label for="height"><small>(height)</small></label>
<div id="recent-resolutions-container">
<span id="recent-resolutions-button" class="clickable"><i class="fa-solid fa-sliders"><span class="simple-tooltip top-left"> Recent sizes </span></i></span>
<div id="recent-resolutions-popup" class="displayNone">
<small>Custom size:</small><br>
<input id="custom-width" name="custom-width" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)">
&times;
<input id="custom-height" name="custom-height" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)"><br>
<small>Enlarge:</small><br>
<div id="enlarge-buttons"><button id="enlarge15" class="tertiaryButton smallButton">×1.5</button>&nbsp;<button id="enlarge2" class="tertiaryButton smallButton">×2</button>&nbsp;<button id="enlarge3" class="tertiaryButton smallButton">×3</button></div>
<small>Recently&nbsp;used:</small><br>
<div id="recent-resolution-list">
</div>
</div>
</div>
<div id="small_image_warning" class="displayNone">Small image sizes can cause bad image quality</div> <div id="small_image_warning" class="displayNone">Small image sizes can cause bad image quality</div>
</td></tr> </td></tr>
<tr class="pl-5"><td><label for="num_inference_steps">Inference Steps:</label></td><td> <input id="num_inference_steps" name="num_inference_steps" type="number" min="1" step="1" style="width: 42pt" value="25" onkeypress="preventNonNumericalInput(event)"></td></tr> <tr class="pl-5"><td><label for="num_inference_steps">Inference Steps:</label></td><td> <input id="num_inference_steps" name="num_inference_steps" type="number" min="1" step="1" style="width: 42pt" value="25" onkeypress="preventNonNumericalInput(event)"></td></tr>
@ -360,6 +386,13 @@
<div class="parameters-table" id="system-settings-table"></div> <div class="parameters-table" id="system-settings-table"></div>
<br/> <br/>
<button id="save-system-settings-btn" class="primaryButton">Save</button> <button id="save-system-settings-btn" class="primaryButton">Save</button>
<div id="install-extras-container" class="displayNone">
<br/>
<div id="install-extras">
<h3><i class="fa fa-cubes-stacked"></i> Optional Packages</h3>
<div class="parameters-table" id="system-settings-install-extras-table"></div>
</div>
</div>
<br/><br/> <br/><br/>
<div id="share-easy-diffusion"> <div id="share-easy-diffusion">
<h3><i class="fa fa-user-group"></i> Share Easy Diffusion</h3> <h3><i class="fa fa-user-group"></i> Share Easy Diffusion</h3>
@ -670,7 +703,7 @@ async function init() {
events: { events: {
statusChange: setServerStatus, statusChange: setServerStatus,
idle: onIdle, idle: onIdle,
ping: tunnelUpdate ping: onPing
} }
}) })
splashScreen() splashScreen()

View File

@ -0,0 +1,68 @@
@keyframes ldio-8f673ktaleu-1 {
0% { transform: rotate(0deg) }
50% { transform: rotate(-45deg) }
100% { transform: rotate(0deg) }
}
@keyframes ldio-8f673ktaleu-2 {
0% { transform: rotate(180deg) }
50% { transform: rotate(225deg) }
100% { transform: rotate(180deg) }
}
.ldio-8f673ktaleu > div:nth-child(2) {
transform: translate(-15px,0);
}
.ldio-8f673ktaleu > div:nth-child(2) div {
position: absolute;
top: 20px;
left: 20px;
width: 60px;
height: 30px;
border-radius: 60px 60px 0 0;
background: #f3b72e;
animation: ldio-8f673ktaleu-1 1s linear infinite;
transform-origin: 30px 30px
}
.ldio-8f673ktaleu > div:nth-child(2) div:nth-child(2) {
animation: ldio-8f673ktaleu-2 1s linear infinite
}
.ldio-8f673ktaleu > div:nth-child(2) div:nth-child(3) {
transform: rotate(-90deg);
animation: none;
}@keyframes ldio-8f673ktaleu-3 {
0% { transform: translate(95px,0); opacity: 0 }
20% { opacity: 1 }
100% { transform: translate(35px,0); opacity: 1 }
}
.ldio-8f673ktaleu > div:nth-child(1) {
display: block;
}
.ldio-8f673ktaleu > div:nth-child(1) div {
position: absolute;
top: 46px;
left: -4px;
width: 8px;
height: 8px;
border-radius: 50%;
background: #3869c5;
animation: ldio-8f673ktaleu-3 1s linear infinite
}
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(1) { animation-delay: -0.67s }
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(2) { animation-delay: -0.33s }
.ldio-8f673ktaleu > div:nth-child(1) div:nth-child(3) { animation-delay: 0s }
.loadingio-spinner-bean-eater-x0y3u8qky4n {
width: 58px;
height: 58px;
display: inline-block;
overflow: hidden;
background: none;
}
.ldio-8f673ktaleu {
width: 100%;
height: 100%;
position: relative;
transform: translateZ(0) scale(0.58);
backface-visibility: hidden;
transform-origin: 0 0; /* see note above */
}
.ldio-8f673ktaleu div { box-sizing: content-box; }
/* generated by https://loading.io/ */

View File

@ -78,6 +78,7 @@
border-bottom-right-radius: 12px; border-bottom-right-radius: 12px;
} }
.parameters-table .fa-fire { .parameters-table .fa-fire,
.parameters-table .fa-bolt {
color: #F7630C; color: #F7630C;
} }

View File

@ -1665,8 +1665,14 @@ body.wait-pause {
} }
#embeddings-list button { #embeddings-list button {
margin-top: 2px; margin: 2px;
margin-bottom: 2px; color: var(--button-color);
background: var(--button-text-color);
font-weight: 700;
}
#embeddings-list button:hover {
background: var(--accent-color);
color: var(--button-text-color);
} }
#embeddings-list .collapsible { #embeddings-list .collapsible {
@ -1747,3 +1753,67 @@ body.wait-pause {
content: "Please restart Easy Diffusion!"; content: "Please restart Easy Diffusion!";
font-size: 10pt; font-size: 10pt;
} }
input#custom-width, input#custom-height {
width: 47pt;
}
div#recent-resolutions-container {
position: relative;
display:inline-block;
}
div#recent-resolutions-popup {
position: absolute;
right: 0px;
margin: 3px;
padding: 0.2em 1em 0.4em 1em;
z-index: 1;
background: var(--background-color3);
border-radius: 4px;
box-shadow: 0 20px 28px 0 rgba(0, 0, 0, 0.15), 0 6px 20px 0 rgba(0, 0, 0, 0.15);
}
div#recent-resolutions-popup small {
opacity: 0.7;
}
td#image-size-options small {
margin-right: 0px !important;
}
td#image-size-options {
white-space: nowrap;
}
div#recent-resolution-list {
text-align: center;
}
div#enlarge-buttons {
text-align: center;
}
.clickable {
cursor: pointer;
}
.imgContainer .spinner {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
margin: 0;
padding: 0;
background: var(--background-color3);
opacity: 0.95;
border-radius: 5px;
padding: 4pt;
border: 1px solid var(--button-color);
box-shadow: 0px 0px 4px black;
}
.imgContainer .spinnerStatus {
font-size: 10pt;
}

View File

@ -1047,7 +1047,9 @@
} }
} }
class FilterTask extends Task { class FilterTask extends Task {
constructor(options = {}) {} constructor(options = {}) {
super(options)
}
/** Send current task to server. /** Send current task to server.
* @param {*} [timeout=-1] Optional timeout value in ms * @param {*} [timeout=-1] Optional timeout value in ms
* @returns the response from the render request. * @returns the response from the render request.
@ -1055,9 +1057,27 @@
*/ */
async post(timeout = -1) { async post(timeout = -1) {
let jsonResponse = await super.post("/filter", timeout) let jsonResponse = await super.post("/filter", timeout)
//this._setId(jsonResponse.task) if (typeof jsonResponse?.task !== "number") {
console.warn("Endpoint error response: ", jsonResponse)
const event = Object.assign({ task: this }, jsonResponse)
await eventSource.fireEvent(EVENT_UNEXPECTED_RESPONSE, event)
if ("continueWith" in event) {
jsonResponse = await Promise.resolve(event.continueWith)
}
if (typeof jsonResponse?.task !== "number") {
const err = new Error(jsonResponse?.detail || "Endpoint response does not contains a task ID.")
this.abort(err)
throw err
}
}
this._setId(jsonResponse.task)
if (jsonResponse.stream) {
this.streamUrl = jsonResponse.stream
}
this._setStatus(TaskStatus.waiting) this._setStatus(TaskStatus.waiting)
return jsonResponse
} }
checkReqBody() {}
enqueue(progressCallback) { enqueue(progressCallback) {
return Task.enqueueNew(this, FilterTask, progressCallback) return Task.enqueueNew(this, FilterTask, progressCallback)
} }
@ -1068,6 +1088,65 @@
if (this.isStopped) { if (this.isStopped) {
return return
} }
this._setStatus(TaskStatus.pending)
progressCallback?.call(this, { reqBody: this._reqBody })
Object.freeze(this._reqBody)
// Post task request to backend
let renderRes = undefined
try {
renderRes = yield this.post()
yield progressCallback?.call(this, { renderResponse: renderRes })
} catch (e) {
yield progressCallback?.call(this, { detail: e.message })
throw e
}
try {
// Wait for task to start on server.
yield this.waitUntil({
callback: function() {
return progressCallback?.call(this, {})
},
status: TaskStatus.processing,
})
} catch (e) {
this.abort(err)
throw e
}
// Task started!
// Open the reader.
const reader = this.reader
const task = this
reader.onError = function(response) {
if (progressCallback) {
task.abort(new Error(response.statusText))
return progressCallback.call(task, { response, reader })
}
return Task.prototype.onError.call(task, response)
}
yield progressCallback?.call(this, { reader })
//Start streaming the results.
const streamGenerator = reader.open()
let value = undefined
let done = undefined
yield progressCallback?.call(this, { stream: streamGenerator })
do {
;({ value, done } = yield streamGenerator.next())
if (typeof value !== "object") {
continue
}
if (value.status !== undefined) {
yield progressCallback?.call(this, value)
if (value.status === "succeeded" || value.status === "failed") {
done = true
}
}
} while (!done)
return value
} }
static start(task, progressCallback) { static start(task, progressCallback) {
if (typeof task !== "object") { if (typeof task !== "object") {

View File

@ -626,6 +626,7 @@ class ImageEditor {
.getImageData(0, 0, this.width, this.height) .getImageData(0, 0, this.width, this.height)
.data.some((channel) => channel !== 0) .data.some((channel) => channel !== 0)
maskSetting.checked = !is_blank maskSetting.checked = !is_blank
maskSetting.dispatchEvent(new Event("change"))
} }
this.hide() this.hide()
} }

View File

@ -5,6 +5,9 @@ const MIN_GPUS_TO_SHOW_SELECTION = 2
const IMAGE_REGEX = new RegExp("data:image/[A-Za-z]+;base64") const IMAGE_REGEX = new RegExp("data:image/[A-Za-z]+;base64")
const htmlTaskMap = new WeakMap() const htmlTaskMap = new WeakMap()
const spinnerPacmanHtml =
'<div class="loadingio-spinner-bean-eater-x0y3u8qky4n"><div class="ldio-8f673ktaleu"><div><div></div><div></div><div></div></div><div><div></div><div></div><div></div></div></div></div>'
const taskConfigSetup = { const taskConfigSetup = {
taskConfig: { taskConfig: {
seed: { value: ({ seed }) => seed, label: "Seed" }, seed: { value: ({ seed }) => seed, label: "Seed" },
@ -46,6 +49,7 @@ const taskConfigSetup = {
use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model }, use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model }, lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
preserve_init_image_color_profile: "Preserve Color Profile", preserve_init_image_color_profile: "Preserve Color Profile",
strict_mask_border: "Strict Mask Border",
}, },
pluginTaskConfig: {}, pluginTaskConfig: {},
getCSSKey: (key) => getCSSKey: (key) =>
@ -74,6 +78,15 @@ let randomSeedField = document.querySelector("#random_seed")
let seedField = document.querySelector("#seed") let seedField = document.querySelector("#seed")
let widthField = document.querySelector("#width") let widthField = document.querySelector("#width")
let heightField = document.querySelector("#height") let heightField = document.querySelector("#height")
let customWidthField = document.querySelector("#custom-width")
let customHeightField = document.querySelector("#custom-height")
let recentResolutionsButton = document.querySelector("#recent-resolutions-button")
let recentResolutionsPopup = document.querySelector("#recent-resolutions-popup")
let recentResolutionList = document.querySelector("#recent-resolution-list")
let enlarge15Button = document.querySelector("#enlarge15")
let enlarge2Button = document.querySelector("#enlarge2")
let enlarge3Button = document.querySelector("#enlarge3")
let swapWidthHeightButton = document.querySelector("#swap-width-height")
let smallImageWarning = document.querySelector("#small_image_warning") let smallImageWarning = document.querySelector("#small_image_warning")
let initImageSelector = document.querySelector("#init_image") let initImageSelector = document.querySelector("#init_image")
let initImagePreview = document.querySelector("#init_image_preview") let initImagePreview = document.querySelector("#init_image_preview")
@ -81,7 +94,9 @@ let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask") let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview") let maskImagePreview = document.querySelector("#mask_preview")
let applyColorCorrectionField = document.querySelector("#apply_color_correction") let applyColorCorrectionField = document.querySelector("#apply_color_correction")
let strictMaskBorderField = document.querySelector("#strict_mask_border")
let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting") let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting")
let strictMaskBorderSetting = document.querySelector("#strict_mask_border_setting")
let promptStrengthSlider = document.querySelector("#prompt_strength_slider") let promptStrengthSlider = document.querySelector("#prompt_strength_slider")
let promptStrengthField = document.querySelector("#prompt_strength") let promptStrengthField = document.querySelector("#prompt_strength")
let samplerField = document.querySelector("#sampler_name") let samplerField = document.querySelector("#sampler_name")
@ -160,6 +175,7 @@ let imagePreviewContent = document.querySelector("#preview-content")
let undoButton = document.querySelector("#undo") let undoButton = document.querySelector("#undo")
let undoBuffer = [] let undoBuffer = []
const UNDO_LIMIT = 20 const UNDO_LIMIT = 20
const MAX_IMG_UNDO_ENTRIES = 5
let loraModels = [] let loraModels = []
@ -271,24 +287,24 @@ function setServerStatus(event) {
// e : MouseEvent // e : MouseEvent
// prompt : Text to be shown as prompt. Should be a question to which "yes" is a good answer. // prompt : Text to be shown as prompt. Should be a question to which "yes" is a good answer.
// fn : function to be called if the user confirms the dialog or has the shift key pressed // fn : function to be called if the user confirms the dialog or has the shift key pressed
// allowSkip: Allow skipping the dialog using the shift key or the confirm_dangerous_actions setting (default: true)
// //
// If the user had the shift key pressed while clicking, the function fn will be executed. // If the user had the shift key pressed while clicking, the function fn will be executed.
// If the setting "confirm_dangerous_actions" in the system settings is disabled, the function // If the setting "confirm_dangerous_actions" in the system settings is disabled, the function
// fn will be executed. // fn will be executed.
// Otherwise, a confirmation dialog is shown. If the user confirms, the function fn will also // Otherwise, a confirmation dialog is shown. If the user confirms, the function fn will also
// be executed. // be executed.
function shiftOrConfirm(e, prompt, fn) { function shiftOrConfirm(e, prompt, fn, allowSkip = true) {
e.stopPropagation() e.stopPropagation()
if (e.shiftKey || !confirmDangerousActionsField.checked) { let tip = allowSkip
? '<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>'
: ""
if (allowSkip && (e.shiftKey || !confirmDangerousActionsField.checked)) {
fn(e) fn(e)
} else { } else {
confirm( confirm(tip, prompt, () => {
'<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>', fn(e)
prompt, })
() => {
fn(e)
}
)
} }
} }
@ -412,6 +428,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
</div> </div>
<button class="imgPreviewItemClearBtn image_clear_btn"><i class="fa-solid fa-xmark"></i></button> <button class="imgPreviewItemClearBtn image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
<span class="img_bottom_label"></span> <span class="img_bottom_label"></span>
<div class="spinner displayNone"><center>${spinnerPacmanHtml}</center><div class="spinnerStatus"></div></div>
</div> </div>
` `
outputContainer.appendChild(imageItemElem) outputContainer.appendChild(imageItemElem)
@ -488,6 +505,8 @@ function showImages(reqBody, res, outputContainer, livePreview) {
const imageSeedLabel = imageItemElem.querySelector(".imgSeedLabel") const imageSeedLabel = imageItemElem.querySelector(".imgSeedLabel")
imageSeedLabel.innerText = "Seed: " + req.seed imageSeedLabel.innerText = "Seed: " + req.seed
const imageUndoBuffer = []
const imageRedoBuffer = []
let buttons = [ let buttons = [
{ text: "Use as Input", on_click: onUseAsInputClick }, { text: "Use as Input", on_click: onUseAsInputClick },
[ [
@ -505,8 +524,10 @@ function showImages(reqBody, res, outputContainer, livePreview) {
{ text: "Make Similar Images", on_click: onMakeSimilarClick }, { text: "Make Similar Images", on_click: onMakeSimilarClick },
{ text: "Draw another 25 steps", on_click: onContinueDrawingClick }, { text: "Draw another 25 steps", on_click: onContinueDrawingClick },
[ [
{ text: "Upscale", on_click: onUpscaleClick, filter: (req, img) => !req.use_upscale }, { html: '<i class="fa-solid fa-undo"></i> Undo', on_click: onUndoFilter },
{ text: "Fix Faces", on_click: onFixFacesClick, filter: (req, img) => !req.use_face_correction }, { html: '<i class="fa-solid fa-redo"></i> Redo', on_click: onRedoFilter },
{ text: "Upscale", on_click: onUpscaleClick },
{ text: "Fix Faces", on_click: onFixFacesClick },
], ],
{ text: "Use as Thumbnail", on_click: onUseAsThumbnailClick }, { text: "Use as Thumbnail", on_click: onUseAsThumbnailClick },
] ]
@ -516,6 +537,14 @@ function showImages(reqBody, res, outputContainer, livePreview) {
const imgItemInfo = imageItemElem.querySelector(".imgItemInfo") const imgItemInfo = imageItemElem.querySelector(".imgItemInfo")
const img = imageItemElem.querySelector("img") const img = imageItemElem.querySelector("img")
const spinner = imageItemElem.querySelector(".spinner")
const spinnerStatus = imageItemElem.querySelector(".spinnerStatus")
const tools = {
spinner: spinner,
spinnerStatus: spinnerStatus,
undoBuffer: imageUndoBuffer,
redoBuffer: imageRedoBuffer,
}
const createButton = function(btnInfo) { const createButton = function(btnInfo) {
if (Array.isArray(btnInfo)) { if (Array.isArray(btnInfo)) {
const wrapper = document.createElement("div") const wrapper = document.createElement("div")
@ -541,8 +570,16 @@ function showImages(reqBody, res, outputContainer, livePreview) {
if (btnInfo.on_click || !isLabel) { if (btnInfo.on_click || !isLabel) {
newButton.addEventListener("click", function(event) { newButton.addEventListener("click", function(event) {
btnInfo.on_click(req, img, event) btnInfo.on_click.bind(newButton)(req, img, event, tools)
}) })
if (btnInfo.on_click === onUndoFilter) {
tools["undoButton"] = newButton
newButton.classList.add("displayNone")
}
if (btnInfo.on_click === onRedoFilter) {
tools["redoButton"] = newButton
newButton.classList.add("displayNone")
}
} }
if (btnInfo.class !== undefined) { if (btnInfo.class !== undefined) {
@ -671,16 +708,86 @@ function enqueueImageVariationTask(req, img, reqDiff) {
createTask(newTaskRequest) createTask(newTaskRequest)
} }
function onUpscaleClick(req, img) { function applyInlineFilter(filterName, path, filterParams, img, statusText, tools) {
enqueueImageVariationTask(req, img, { const filterReq = {
use_upscale: upscaleModelField.value, image: img.src,
filter: filterName,
model_paths: {},
filter_params: filterParams,
output_format: outputFormatField.value,
output_quality: parseInt(outputQualityField.value),
output_lossless: outputLosslessField.checked,
}
filterReq.model_paths[filterName] = path
tools.spinnerStatus.innerText = statusText
tools.spinner.classList.remove("displayNone")
SD.filter(filterReq, (e) => {
if (e.status === "succeeded") {
let prevImg = img.src
img.src = e.output[0]
tools.spinner.classList.add("displayNone")
if (prevImg.length > 0) {
tools.undoBuffer.push(prevImg)
tools.redoBuffer = []
if (tools.undoBuffer.length > MAX_IMG_UNDO_ENTRIES) {
let n = tools.undoBuffer.length
tools.undoBuffer.splice(0, n - MAX_IMG_UNDO_ENTRIES)
}
tools.undoButton.classList.remove("displayNone")
tools.redoButton.classList.add("displayNone")
}
} else if (e.status == "failed") {
alert("Error running upscale: " + e.detail)
tools.spinner.classList.add("displayNone")
}
}) })
} }
function onFixFacesClick(req, img) { function moveImageBetweenBuffers(img, fromBuffer, toBuffer, fromButton, toButton) {
enqueueImageVariationTask(req, img, { if (fromBuffer.length === 0) {
use_face_correction: gfpganModelField.value, return
}) }
let src = fromBuffer.pop()
if (src.length > 0) {
toBuffer.push(img.src)
img.src = src
}
if (fromBuffer.length === 0) {
fromButton.classList.add("displayNone")
}
if (toBuffer.length > 0) {
toButton.classList.remove("displayNone")
}
}
function onUndoFilter(req, img, e, tools) {
moveImageBetweenBuffers(img, tools.undoBuffer, tools.redoBuffer, tools.undoButton, tools.redoButton)
}
function onRedoFilter(req, img, e, tools) {
moveImageBetweenBuffers(img, tools.redoBuffer, tools.undoBuffer, tools.redoButton, tools.undoButton)
}
function onUpscaleClick(req, img, e, tools) {
let path = upscaleModelField.value
let scale = parseInt(upscaleAmountField.value)
let filterName = path.toLowerCase().includes("realesrgan") ? "realesrgan" : "latent_upscaler"
let statusText = "Upscaling by " + scale + "x using " + filterName
applyInlineFilter(filterName, path, { scale: scale }, img, statusText, tools)
}
function onFixFacesClick(req, img, e, tools) {
let path = gfpganModelField.value
let filterName = path.toLowerCase().includes("gfpgan") ? "gfpgan" : "codeformer"
let statusText = "Fixing faces with " + filterName
applyInlineFilter(filterName, path, {}, img, statusText, tools)
} }
function onContinueDrawingClick(req, img) { function onContinueDrawingClick(req, img) {
@ -924,7 +1031,9 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
<a href="https://www.ibm.com/docs/en/opw/8.2.0?topic=tuning-optional-increasing-paging-file-size-windows-computers" target="_blank">Windows</a> or <a href="https://www.ibm.com/docs/en/opw/8.2.0?topic=tuning-optional-increasing-paging-file-size-windows-computers" target="_blank">Windows</a> or
<a href="https://linuxhint.com/increase-swap-space-linux/" target="_blank">Linux</a>.<br/> <a href="https://linuxhint.com/increase-swap-space-linux/" target="_blank">Linux</a>.<br/>
3. Try restarting your computer.<br/>` 3. Try restarting your computer.<br/>`
} else if (msg.includes("RuntimeError: output with shape [320, 320] doesn't match the broadcast shape")) { } else if (
msg.includes("RuntimeError: output with shape [320, 320] doesn't match the broadcast shape")
) {
msg += `<br/><br/> msg += `<br/><br/>
<b>Reason</b>: You tried to use a LORA that was trained for a different Stable Diffusion model version! <b>Reason</b>: You tried to use a LORA that was trained for a different Stable Diffusion model version!
<br/><br/> <br/><br/>
@ -1298,6 +1407,7 @@ function getCurrentUserRequest() {
// } // }
if (maskSetting.checked) { if (maskSetting.checked) {
newTask.reqBody.mask = imageInpainter.getImg() newTask.reqBody.mask = imageInpainter.getImg()
newTask.reqBody.strict_mask_border = strictMaskBorderField.checked
} }
newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked
if (!testDiffusers.checked) { if (!testDiffusers.checked) {
@ -1338,6 +1448,11 @@ function getCurrentUserRequest() {
newTask.reqBody.lora_alpha = modelStrengths newTask.reqBody.lora_alpha = modelStrengths
} }
} }
if (testDiffusers.checked && document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
// TRT is installed
newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked
}
return newTask return newTask
} }
@ -1998,6 +2113,7 @@ function img2imgLoad() {
} }
initImagePreviewContainer.classList.add("has-image") initImagePreviewContainer.classList.add("has-image")
colorCorrectionSetting.style.display = "" colorCorrectionSetting.style.display = ""
strictMaskBorderSetting.style.display = maskSetting.checked ? "" : "none"
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight) imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@ -2015,6 +2131,7 @@ function img2imgUnload() {
} }
initImagePreviewContainer.classList.remove("has-image") initImagePreviewContainer.classList.remove("has-image")
colorCorrectionSetting.style.display = "none" colorCorrectionSetting.style.display = "none"
strictMaskBorderSetting.style.display = "none"
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value)) imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
} }
initImagePreview.addEventListener("load", img2imgLoad) initImagePreview.addEventListener("load", img2imgLoad)
@ -2023,6 +2140,9 @@ initImageClearBtn.addEventListener("click", img2imgUnload)
maskSetting.addEventListener("click", function() { maskSetting.addEventListener("click", function() {
onDimensionChange() onDimensionChange()
}) })
maskSetting.addEventListener("change", function() {
strictMaskBorderSetting.style.display = this.checked ? "" : "none"
})
promptsFromFileBtn.addEventListener("click", function() { promptsFromFileBtn.addEventListener("click", function() {
promptsFromFileSelector.click() promptsFromFileSelector.click()
@ -2138,6 +2258,11 @@ resumeBtn.addEventListener("click", function() {
document.body.classList.remove("wait-pause") document.body.classList.remove("wait-pause")
}) })
function onPing(event) {
tunnelUpdate(event)
packagesUpdate(event)
}
function tunnelUpdate(event) { function tunnelUpdate(event) {
if ("cloudflare" in event) { if ("cloudflare" in event) {
document.getElementById("cloudflare-off").classList.add("displayNone") document.getElementById("cloudflare-off").classList.add("displayNone")
@ -2151,6 +2276,23 @@ function tunnelUpdate(event) {
} }
} }
function packagesUpdate(event) {
let trtBtn = document.getElementById("toggle-tensorrt-install")
let trtInstalled = "packages_installed" in event && "tensorrt" in event["packages_installed"]
if ("packages_installing" in event && event["packages_installing"].includes("tensorrt")) {
trtBtn.innerHTML = "Installing.."
trtBtn.disabled = true
} else {
trtBtn.innerHTML = trtInstalled ? "Uninstall" : "Install"
trtBtn.disabled = false
}
if (document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") {
document.querySelector("#enable_trt_config").classList.remove("displayNone")
}
}
document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", async function() { document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", async function() {
let command = "stop" let command = "stop"
if (document.getElementById("toggle-cloudflare-tunnel").innerHTML == "Start") { if (document.getElementById("toggle-cloudflare-tunnel").innerHTML == "Start") {
@ -2170,6 +2312,63 @@ document.getElementById("toggle-cloudflare-tunnel").addEventListener("click", as
console.log(`Cloudflare tunnel ${command} result:`, res) console.log(`Cloudflare tunnel ${command} result:`, res)
}) })
document.getElementById("toggle-tensorrt-install").addEventListener("click", function(e) {
if (this.disabled === true) {
return
}
let command = this.innerHTML.toLowerCase()
let self = this
shiftOrConfirm(
e,
"Are you sure you want to " + command + " TensorRT?",
async function() {
showToast(`TensorRT ${command} started. Please wait.`)
self.disabled = true
if (command === "install") {
self.innerHTML = "Installing.."
} else if (command === "uninstall") {
self.innerHTML = "Uninstalling.."
}
if (command === "installing..") {
alert("Already installing TensorRT!")
return
}
if (command !== "install" && command !== "uninstall") {
return
}
let res = await fetch("/package/tensorrt", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
command: command,
}),
})
res = await res.json()
self.disabled = false
if (res.status === "OK") {
alert("TensorRT " + command + "ed successfully!")
self.innerHTML = command === "install" ? "Uninstall" : "Install"
} else if (res.status_code === 500) {
alert("TensorselfRT failed to " + command + ": " + res.detail)
self.innerHTML = command === "install" ? "Install" : "Uninstall"
}
console.log(`Package ${command} result:`, res)
},
false
)
})
/* Embeddings */ /* Embeddings */
let icl = [] let icl = []
@ -2194,7 +2393,10 @@ function updateEmbeddingsList(filter = "") {
} else { } else {
let subdir = html(m[1], iconlist, prefix + m[0] + "/", filter) let subdir = html(m[1], iconlist, prefix + m[0] + "/", filter)
if (subdir != "") { if (subdir != "") {
folders += `<div class="embedding-category"><h4 class="collapsible">${prefix}${m[0]}</h4><div class="collapsible-content">` + subdir + '</div></div>' folders +=
`<div class="embedding-category"><h4 class="collapsible">${prefix}${m[0]}</h4><div class="collapsible-content">` +
subdir +
"</div></div>"
} }
} }
}) })
@ -2320,7 +2522,6 @@ embeddingsCollapsiblesBtn.addEventListener("click", (e) => {
} }
}) })
if (testDiffusers.checked) { if (testDiffusers.checked) {
document.getElementById("embeddings-container").classList.remove("displayNone") document.getElementById("embeddings-container").classList.remove("displayNone")
} }
@ -2417,3 +2618,172 @@ createLoraEntries()
// } // }
// document.querySelectorAll("input[type=number]").forEach(showSpinnerOnlyOnHover) // document.querySelectorAll("input[type=number]").forEach(showSpinnerOnlyOnHover)
////////////////////////////// Image Size Widget //////////////////////////////////////////
function roundToMultiple(number, n) {
if (n == "") {
n = 1
}
return Math.round(number / n) * n
}
function addImageSizeOption(size) {
let sizes = Object.values(widthField.options).map((o) => o.value)
if (!sizes.includes(String(size))) {
sizes.push(String(size))
sizes.sort((a, b) => Number(a) - Number(b))
let option = document.createElement("option")
option.value = size
option.text = `${size}`
widthField.add(option, sizes.indexOf(String(size)))
heightField.add(option.cloneNode(true), sizes.indexOf(String(size)))
}
}
function setImageWidthHeight(w, h) {
let step = customWidthField.step
w = roundToMultiple(w, step)
h = roundToMultiple(h, step)
addImageSizeOption(w)
addImageSizeOption(h)
widthField.value = w
heightField.value = h
widthField.dispatchEvent(new Event("change"))
heightField.dispatchEvent(new Event("change"))
}
function enlargeImageSize(factor) {
let step = customWidthField.step
let w = roundToMultiple(widthField.value * factor, step)
let h = roundToMultiple(heightField.value * factor, step)
customWidthField.value = w
customHeightField.value = h
}
let recentResolutionsValues = []
;(function() {
///// Init resolutions dropdown
function makeResolutionButtons() {
recentResolutionList.innerHTML = ""
recentResolutionsValues.forEach((el) => {
let button = document.createElement("button")
button.classList.add("tertiaryButton")
button.style.width = "8em"
button.innerHTML = `${el.w}&times;${el.h}`
button.addEventListener("click", () => {
customWidthField.value = el.w
customHeightField.value = el.h
hidePopup()
})
recentResolutionList.appendChild(button)
recentResolutionList.appendChild(document.createElement("br"))
})
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
}
enlarge15Button.addEventListener("click", () => {
enlargeImageSize(1.5)
hidePopup()
})
enlarge2Button.addEventListener("click", () => {
enlargeImageSize(2)
hidePopup()
})
enlarge3Button.addEventListener("click", () => {
enlargeImageSize(3)
hidePopup()
})
customWidthField.addEventListener("change", () => {
let w = customWidthField.value
customWidthField.value = roundToMultiple(w, customWidthField.step)
if (w != customWidthField.value) {
showToast(`Rounded width to the closest multiple of ${customWidthField.step}.`)
}
})
customHeightField.addEventListener("change", () => {
let h = customHeightField.value
customHeightField.value = roundToMultiple(h, customHeightField.step)
if (h != customHeightField.value) {
showToast(`Rounded height to the closest multiple of ${customHeightField.step}.`)
}
})
makeImageBtn.addEventListener("click", () => {
let w = widthField.value
let h = heightField.value
recentResolutionsValues = recentResolutionsValues.filter((el) => el.w != w || el.h != h)
recentResolutionsValues.unshift({ w: w, h: h })
recentResolutionsValues = recentResolutionsValues.slice(0, 8)
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
makeResolutionButtons()
})
let _jsonstring = localStorage.recentResolutionsValues
if (_jsonstring == undefined) {
recentResolutionsValues = [
{ w: 512, h: 512 },
{ w: 640, h: 448 },
{ w: 448, h: 640 },
{ w: 512, h: 768 },
{ w: 768, h: 512 },
{ w: 1024, h: 768 },
{ w: 768, h: 1024 },
]
localStorage.recentResolutionsValues = JSON.stringify(recentResolutionsValues)
} else {
recentResolutionsValues = JSON.parse(localStorage.recentResolutionsValues)
}
makeResolutionButtons()
recentResolutionsValues.forEach((val) => {
addImageSizeOption(val.w)
addImageSizeOption(val.h)
})
function processClick(e) {
if (!recentResolutionsPopup.contains(e.target)) {
hidePopup()
}
}
function showPopup() {
customWidthField.value = widthField.value
customHeightField.value = heightField.value
recentResolutionsPopup.classList.remove("displayNone")
document.addEventListener("click", processClick)
}
function hidePopup() {
recentResolutionsPopup.classList.add("displayNone")
setImageWidthHeight(customWidthField.value, customHeightField.value)
document.removeEventListener("click", processClick)
}
recentResolutionsButton.addEventListener("click", (event) => {
if (recentResolutionsPopup.classList.contains("displayNone")) {
showPopup()
event.stopPropagation()
} else {
hidePopup()
}
})
swapWidthHeightButton.addEventListener("click", (event) => {
let temp = widthField.value
widthField.value = heightField.value
heightField.value = temp
})
})()

View File

@ -16,6 +16,7 @@ var ParameterType = {
*/ */
let parametersTable = document.querySelector("#system-settings-table") let parametersTable = document.querySelector("#system-settings-table")
let networkParametersTable = document.querySelector("#system-settings-network-table") let networkParametersTable = document.querySelector("#system-settings-network-table")
let installExtrasTable = document.querySelector("#system-settings-install-extras-table")
/** /**
* JSDoc style * JSDoc style
@ -240,7 +241,18 @@ var PARAMETERS = [
icon: ["fa-brands", "fa-cloudflare"], icon: ["fa-brands", "fa-cloudflare"],
render: () => '<button id="toggle-cloudflare-tunnel" class="primaryButton">Start</button>', render: () => '<button id="toggle-cloudflare-tunnel" class="primaryButton">Start</button>',
table: networkParametersTable, table: networkParametersTable,
} },
{
id: "nvidia_tensorrt",
type: ParameterType.custom,
label: "NVIDIA TensorRT",
note: `Faster image generation by converting your Stable Diffusion models to the NVIDIA TensorRT format. You can choose the
models to convert. Download size: approximately 2 GB.<br/><br/>
<b>Early access version:</b> support for LoRA is still under development.`,
icon: "fa-angles-up",
render: () => '<button id="toggle-tensorrt-install" class="primaryButton">Install</button>',
table: installExtrasTable,
},
] ]
function getParameterSettingsEntry(id) { function getParameterSettingsEntry(id) {
@ -315,7 +327,7 @@ function initParameters(parameters) {
noteElements.push(noteElement) noteElements.push(noteElement)
} }
if (typeof(parameter.icon) == "string") { if (typeof parameter.icon == "string") {
parameter.icon = [parameter.icon] parameter.icon = [parameter.icon]
} }
const icon = parameter.icon ? [createElement("i", undefined, ["fa", ...parameter.icon])] : [] const icon = parameter.icon ? [createElement("i", undefined, ["fa", ...parameter.icon])] : []
@ -342,7 +354,7 @@ function initParameters(parameters) {
let p = parametersTable let p = parametersTable
if (parameter.table) { if (parameter.table) {
p = parameter.table p = parameter.table
} }
p.appendChild(newrow) p.appendChild(newrow)
parameter.settingsEntry = newrow parameter.settingsEntry = newrow
@ -409,7 +421,7 @@ async function getAppConfig() {
useBetaChannelField.checked = true useBetaChannelField.checked = true
document.querySelector("#updateBranchLabel").innerText = "(beta)" document.querySelector("#updateBranchLabel").innerText = "(beta)"
} else { } else {
getParameterSettingsEntry("test_diffusers").style.display = "none" getParameterSettingsEntry("test_diffusers").classList.add("displayNone")
} }
if (config.ui && config.ui.open_browser_on_start === false) { if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false uiOpenBrowserOnStartField.checked = false
@ -426,11 +438,11 @@ async function getAppConfig() {
if (config.config_on_startup) { if (config.config_on_startup) {
if (config.config_on_startup?.test_diffusers && config.update_branch !== "main") { if (config.config_on_startup?.test_diffusers && config.update_branch !== "main") {
document.body.classList.add("diffusers-enabled-on-startup"); document.body.classList.add("diffusers-enabled-on-startup")
document.body.classList.remove("diffusers-disabled-on-startup"); document.body.classList.remove("diffusers-disabled-on-startup")
} else { } else {
document.body.classList.add("diffusers-disabled-on-startup"); document.body.classList.add("diffusers-disabled-on-startup")
document.body.classList.remove("diffusers-enabled-on-startup"); document.body.classList.remove("diffusers-enabled-on-startup")
} }
} }
@ -441,16 +453,20 @@ async function getAppConfig() {
document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => { document.querySelectorAll("#sampler_name option.diffusers-only").forEach((option) => {
option.style.display = "none" option.style.display = "none"
}) })
customWidthField.step=64
customHeightField.step=64
} else { } else {
document.querySelector("#lora_model_container").style.display = "" document.querySelector("#lora_model_container").style.display = ""
document.querySelector("#tiling_container").style.display = "" document.querySelector("#tiling_container").style.display = ""
document.querySelectorAll("#sampler_name option.k_diffusion-only").forEach((option) => { document.querySelectorAll("#sampler_name option.k_diffusion-only").forEach((option) => {
option.disabled = true option.style.display = "none"
}) })
document.querySelector("#clip_skip_config").classList.remove("displayNone") document.querySelector("#clip_skip_config").classList.remove("displayNone")
document.querySelector("#embeddings-button").classList.remove("displayNone") document.querySelector("#embeddings-button").classList.remove("displayNone")
document.querySelector("#negative-embeddings-button").classList.remove("displayNone") document.querySelector("#negative-embeddings-button").classList.remove("displayNone")
customWidthField.step=8
customHeightField.step=8
} }
console.log("get config status response", config) console.log("get config status response", config)
@ -582,6 +598,23 @@ function setDeviceInfo(devices) {
systemInfoEl.querySelector("#system-info-cpu").innerText = cpu systemInfoEl.querySelector("#system-info-cpu").innerText = cpu
systemInfoEl.querySelector("#system-info-gpus-all").innerHTML = allGPUs.join("</br>") systemInfoEl.querySelector("#system-info-gpus-all").innerHTML = allGPUs.join("</br>")
systemInfoEl.querySelector("#system-info-rendering-devices").innerHTML = activeGPUs.join("</br>") systemInfoEl.querySelector("#system-info-rendering-devices").innerHTML = activeGPUs.join("</br>")
// tensorRT
if (devices.active && testDiffusers.checked && devices.enable_trt === true) {
let nvidiaGPUs = Object.keys(devices.active).filter((d) => {
let gpuName = devices.active[d].name
gpuName = gpuName.toLowerCase()
return (
gpuName.includes("nvidia") ||
gpuName.includes("geforce") ||
gpuName.includes("quadro") ||
gpuName.includes("tesla")
)
})
if (nvidiaGPUs.length > 0) {
document.querySelector("#install-extras-container").classList.remove("displayNone")
}
}
} }
function setHostInfo(hosts) { function setHostInfo(hosts) {
@ -674,7 +707,7 @@ saveSettingsBtn.addEventListener("click", function() {
update_branch: updateBranch, update_branch: updateBranch,
} }
document.querySelectorAll('#system-settings [data-setting-id]').forEach((parameterRow) => { document.querySelectorAll("#system-settings [data-setting-id]").forEach((parameterRow) => {
if (parameterRow.dataset.saveInAppConfig === "true") { if (parameterRow.dataset.saveInAppConfig === "true") {
const parameterElement = const parameterElement =
document.getElementById(parameterRow.dataset.settingId) || document.getElementById(parameterRow.dataset.settingId) ||
@ -713,28 +746,41 @@ saveSettingsBtn.addEventListener("click", function() {
Promise.all([savePromise, asyncDelay(300)]).then(() => saveSettingsBtn.classList.remove("active")) Promise.all([savePromise, asyncDelay(300)]).then(() => saveSettingsBtn.classList.remove("active"))
}) })
listenToNetworkField.addEventListener("change", debounce( ()=>{ listenToNetworkField.addEventListener(
saveSettingsBtn.click() "change",
}, 1000)) debounce(() => {
saveSettingsBtn.click()
}, 1000)
)
listenPortField.addEventListener("change", debounce( ()=>{ listenPortField.addEventListener(
saveSettingsBtn.click() "change",
}, 1000)) debounce(() => {
saveSettingsBtn.click()
}, 1000)
)
let copyCloudflareAddressBtn = document.querySelector("#copy-cloudflare-address") let copyCloudflareAddressBtn = document.querySelector("#copy-cloudflare-address")
let cloudflareAddressField = document.getElementById("cloudflare-address") let cloudflareAddressField = document.getElementById("cloudflare-address")
navigator.permissions.query({ name: "clipboard-write" }).then(function (result) { navigator.permissions.query({ name: "clipboard-write" }).then(function(result) {
if (result.state === "granted") { if (result.state === "granted") {
// you can read from the clipboard // you can read from the clipboard
copyCloudflareAddressBtn.addEventListener("click", (e) => { copyCloudflareAddressBtn.addEventListener("click", (e) => {
navigator.clipboard.writeText(cloudflareAddressField.innerHTML) navigator.clipboard.writeText(cloudflareAddressField.innerHTML)
showToast("Copied server address to clipboard") showToast("Copied server address to clipboard")
}) })
} else { } else {
copyCloudflareAddressBtn.classList.add("displayNone") copyCloudflareAddressBtn.classList.add("displayNone")
} }
}); })
document.addEventListener("system_info_update", (e) => setDeviceInfo(e.detail)) document.addEventListener("system_info_update", (e) => setDeviceInfo(e.detail))
useBetaChannelField.addEventListener("change", (e) => {
if (e.target.checked) {
getParameterSettingsEntry("test_diffusers").classList.remove("displayNone")
} else {
getParameterSettingsEntry("test_diffusers").classList.add("displayNone")
}
})

View File

@ -109,8 +109,10 @@
imageObj.onload = function() { imageObj.onload = function() {
// Calculate the maximum cropped dimensions // Calculate the maximum cropped dimensions
const maxCroppedWidth = Math.floor(this.width / 64) * 64; const step = customWidthField.step
const maxCroppedHeight = Math.floor(this.height / 64) * 64;
const maxCroppedWidth = Math.floor(this.width / step) * step;
const maxCroppedHeight = Math.floor(this.height / step) * step;
canvas.width = maxCroppedWidth; canvas.width = maxCroppedWidth;
canvas.height = maxCroppedHeight; canvas.height = maxCroppedHeight;