diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index bdecc109..1ee5ce9d 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -5,7 +5,7 @@ import traceback from typing import Union from easydiffusion import app -from easydiffusion.types import TaskData +from easydiffusion.types import ModelsData from easydiffusion.utils import log from sdkit import Context from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db @@ -57,7 +57,9 @@ def init(): def load_default_models(context: Context): - set_vram_optimizations(context) + from easydiffusion import runtime + + runtime.set_vram_optimizations(context) config = app.getConfig() context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings") @@ -138,43 +140,32 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?") -def reload_models_if_necessary(context: Context, task_data: TaskData): - face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else "" - upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else "" - - model_paths_in_req = { - "stable-diffusion": task_data.use_stable_diffusion_model, - "vae": task_data.use_vae_model, - "hypernetwork": task_data.use_hypernetwork_model, - "codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None, - "gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None, - "realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None, - "latent_upscaler": True if "latent_upscaler" in upscale_lower else None, - "nsfw_checker": True if task_data.block_nsfw else None, - "lora": task_data.use_lora_model, - } +def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []): models_to_reload = { model_type: path - for model_type, path in model_paths_in_req.items() + for model_type, path in models_data.model_paths.items() if context.model_paths.get(model_type) != path } - if task_data.codeformer_upscale_faces: + if models_data.model_paths.get("codeformer"): if "realesrgan" not in models_to_reload and "realesrgan" not in context.models: default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None: del models_to_reload["realesrgan"] # don't unload realesrgan - if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD - models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"] + for model_type in models_to_force_reload: + if model_type not in models_data.model_paths: + continue + models_to_reload[model_type] = models_data.model_paths[model_type] for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req action_fn = unload_model if context.model_paths[model_type] is None else load_model + extra_params = models_data.model_params.get(model_type, {}) try: - action_fn(context, model_type, scan_model=False) # we've scanned them already + action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already if model_type in context.model_load_errors: del context.model_load_errors[model_type] except Exception as e: @@ -183,24 
+174,15 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks -def resolve_model_paths(task_data: TaskData): - task_data.use_stable_diffusion_model = resolve_model_to_use( - task_data.use_stable_diffusion_model, model_type="stable-diffusion" - ) - task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae") - task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork") - task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora") - - if task_data.use_face_correction: - if "gfpgan" in task_data.use_face_correction.lower(): - model_type = "gfpgan" - elif "codeformer" in task_data.use_face_correction.lower(): - model_type = "codeformer" +def resolve_model_paths(models_data: ModelsData): + model_paths = models_data.model_paths + for model_type in model_paths: + if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths + continue + if model_type == "codeformer": download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") - task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type) - if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): - task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan") + model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type) def fail_if_models_did_not_load(context: Context): @@ -235,17 +217,6 @@ def download_if_necessary(model_type: str, file_name: str, model_id: str): download_model(model_type, model_id, download_base_dir=app.MODELS_DIR) -def set_vram_optimizations(context: Context): - config = app.getConfig() - vram_usage_level = config.get("vram_usage_level", "balanced") - - if vram_usage_level != context.vram_usage_level: - context.vram_usage_level = vram_usage_level - return True - - return False - - def migrate_legacy_model_location(): 'Move the models inside the legacy "stable-diffusion" folder, to their respective folders' @@ -266,16 +237,6 @@ def any_model_exists(model_type: str) -> bool: return False -def set_clip_skip(context: Context, task_data: TaskData): - clip_skip = task_data.clip_skip - - if clip_skip != context.clip_skip: - context.clip_skip = clip_skip - return True - - return False - - def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) diff --git a/ui/easydiffusion/runtime.py b/ui/easydiffusion/runtime.py new file mode 100644 index 00000000..4098ee8e --- /dev/null +++ b/ui/easydiffusion/runtime.py @@ -0,0 +1,53 @@ +""" +A runtime that runs on a specific device (in a thread). + +It can run various tasks like image generation, image filtering, model merging, etc., by using that thread-local context. + +This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function. +""" + +from easydiffusion import device_manager +from easydiffusion.utils import log +from sdkit import Context +from sdkit.utils import get_device_usage + +context = Context() # thread-local +""" +runtime data (bound locally to this thread), e.g.
device, references to loaded models, optimization flags, etc. """ + + +def init(device): + """ + Initializes the fields that will be bound to this runtime's context, and sets the current torch device + """ + context.stop_processing = False + context.temp_images = {} + context.partial_x_samples = None + context.model_load_errors = {} + context.enable_codeformer = True + + from easydiffusion import app + + app_config = app.getConfig() + context.test_diffusers = ( + app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main" + ) + + log.info("Device usage during initialization:") + get_device_usage(device, log_info=True, process_usage_only=False) + + device_manager.device_init(context, device) + + +def set_vram_optimizations(context: Context): + from easydiffusion import app + + config = app.getConfig() + vram_usage_level = config.get("vram_usage_level", "balanced") + + if vram_usage_level != context.vram_usage_level: + context.vram_usage_level = vram_usage_level + return True + + return False diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index df788b0c..0f1890c3 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -9,7 +9,16 @@ import traceback from typing import List, Union from easydiffusion import app, model_manager, task_manager -from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData +from easydiffusion.tasks import RenderTask, FilterTask +from easydiffusion.types import ( + GenerateImageRequest, + FilterImageRequest, + MergeRequest, + TaskData, + ModelsData, + OutputFormatData, + convert_legacy_render_req_to_new, +) from easydiffusion.utils import log from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles @@ -97,6 +106,10 @@ def init(): def render(req: dict): return render_internal(req) + @server_api.post("/filter") + def filter_image(req: dict): + return filter_internal(req) + @server_api.post("/model/merge") def model_merge(req: dict): print(req) @@ -228,9 +241,13 @@ def ping_internal(session_id: str = None): def render_internal(req: dict): try: + req = convert_legacy_render_req_to_new(req) + # separate out the request data into rendering and task-specific data render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req) task_data: TaskData = TaskData.parse_obj(req) + models_data: ModelsData = ModelsData.parse_obj(req) + output_format: OutputFormatData = OutputFormatData.parse_obj(req) # Overwrite user specified save path config = app.getConfig() @@ -240,28 +257,53 @@ def render_internal(req: dict): render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision app.save_to_config( - task_data.use_stable_diffusion_model, - task_data.use_vae_model, - task_data.use_hypernetwork_model, + models_data.model_paths.get("stable-diffusion"), + models_data.model_paths.get("vae"), + models_data.model_paths.get("hypernetwork"), task_data.vram_usage_level, ) # enqueue the task - new_task = task_manager.render(render_req, task_data) + task = RenderTask(render_req, task_data, models_data, output_format) + return enqueue_task(task) + except HTTPException as e: + raise e + except Exception as e: + log.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + +def filter_internal(req: dict): + try: + session_id = req.get("session_id", "session") + filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req) + models_data: ModelsData = ModelsData.parse_obj(req) + output_format:
OutputFormatData = OutputFormatData.parse_obj(req) + + # enqueue the task + task = FilterTask(filter_req, session_id, models_data, output_format) + return enqueue_task(task) + except HTTPException as e: + raise e + except Exception as e: + log.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + + +def enqueue_task(task): + try: + task_manager.enqueue_task(task) response = { "status": str(task_manager.current_state), "queue": len(task_manager.tasks_queue), - "stream": f"/image/stream/{id(new_task)}", - "task": id(new_task), + "stream": f"/image/stream/{task.id}", + "task": task.id, } return JSONResponse(response, headers=NOCACHE_HEADERS) except ChildProcessError as e: # Render thread is dead raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable - except Exception as e: - log.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) def model_merge_internal(req: dict): diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index a91cd9c6..27b53b6f 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -17,7 +17,7 @@ from typing import Any, Hashable import torch from easydiffusion import device_manager -from easydiffusion.types import GenerateImageRequest, TaskData +from easydiffusion.tasks import Task from easydiffusion.utils import log from sdkit.utils import gc @@ -27,6 +27,7 @@ LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init. +MAX_OVERLOAD_ALLOWED_RATIO = 2 # i.e. 2x pending tasks compared to the number of render threads class SymbolClass(type): # Print nicely formatted Symbol names. @@ -58,46 +59,6 @@ class ServerStates: pass -class RenderTask: # Task with output queue and completion lock. - def __init__(self, req: GenerateImageRequest, task_data: TaskData): - task_data.request_id = id(self) - self.render_request: GenerateImageRequest = req # Initial Request - self.task_data: TaskData = task_data - self.response: Any = None # Copy of the last reponse - self.render_device = None # Select the task affinity. (Not used to change active devices). - self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) - self.error: Exception = None - self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed - self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments - - async def read_buffer_generator(self): - try: - while not self.buffer_queue.empty(): - res = self.buffer_queue.get(block=False) - self.buffer_queue.task_done() - yield res - except queue.Empty as e: - yield - - @property - def status(self): - if self.lock.locked(): - return "running" - if isinstance(self.error, StopAsyncIteration): - return "stopped" - if self.error: - return "error" - if not self.buffer_queue.empty(): - return "buffer" - if self.response: - return "completed" - return "pending" - - @property - def is_pending(self): - return bool(not self.response and not self.error) - - # Temporary cache to allow to query tasks results for a short time after they are completed. 
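For a sense of how `filter_internal` above slices one flat JSON body into the three parsed objects, here is a minimal client-side sketch using the `requests` library. The port, model name, and image path are illustrative assumptions, not values from this diff; the response shape comes from `enqueue_task` above (the filtered images themselves arrive via the returned `stream` URL).

```python
import base64

import requests

SERVER = "http://localhost:9000"  # assumed port of a local Easy Diffusion server

with open("input.png", "rb") as f:  # placeholder input image
    image_b64 = "data:image/png;base64," + base64.b64encode(f.read()).decode()

req = {
    "session_id": "demo-session",
    # FilterImageRequest: a single filter name (or a list of them) plus its params
    "image": image_b64,
    "filter": "gfpgan",
    "filter_params": {},
    # ModelsData: which model file serves the filter (illustrative model name)
    "model_paths": {"gfpgan": "GFPGANv1.4"},
    # OutputFormatData
    "output_format": "png",
}

res = requests.post(f"{SERVER}/filter", json=req)
res.raise_for_status()
info = res.json()  # {"status": ..., "queue": ..., "stream": "/image/stream/<id>", "task": <id>}
print(info["task"], info["stream"])
```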
class DataCache: def __init__(self): @@ -123,8 +84,8 @@ class DataCache: # Remove Items for key in to_delete: (_, val) = self._base[key] - if isinstance(val, RenderTask): - log.debug(f"RenderTask {key} expired. Data removed.") + if isinstance(val, Task): + log.debug(f"Task {key} expired. Data removed.") elif isinstance(val, SessionState): log.debug(f"Session {key} expired. Data removed.") else: @@ -220,8 +181,8 @@ class SessionState: tasks.append(task) return tasks - def put(self, task, ttl=TASK_TTL): - task_id = id(task) + def put(self, task: Task, ttl=TASK_TTL): + task_id = task.id self._tasks_ids.append(task_id) if not task_cache.put(task_id, task, ttl): return False @@ -230,11 +191,16 @@ class SessionState: return True +def keep_task_alive(task: Task): + task_cache.keep(task.id, TASK_TTL) + session_cache.keep(task.session_id, TASK_TTL) + + def thread_get_next_task(): - from easydiffusion import renderer + from easydiffusion import runtime if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.") + log.warn(f"Render thread on device: {runtime.context.device} failed to acquire manager lock.") return None if len(tasks_queue) <= 0: manager_lock.release() @@ -242,7 +208,7 @@ def thread_get_next_task(): task = None try: # Select a render task. for queued_task in tasks_queue: - if queued_task.render_device and renderer.context.device != queued_task.render_device: + if queued_task.render_device and runtime.context.device != queued_task.render_device: # Is asking for a specific render device. if is_alive(queued_task.render_device) > 0: continue # requested device alive, skip current one. @@ -251,7 +217,7 @@ def thread_get_next_task(): queued_task.error = Exception(queued_task.render_device + " is not currently active.") task = queued_task break - if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1: + if not queued_task.render_device and runtime.context.device == "cpu" and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. 
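The selection loop above encodes two affinity rules: a task pinned to a `render_device` runs only there (and is failed fast when that device is gone), and the CPU thread declines unpinned tasks while any other render thread is alive. A simplified restatement of that policy as a pure function (the real loop also sets `queued_task.error` and returns the failed task):

```python
def thread_should_pick(task_device, thread_device: str, alive_thread_count: int) -> bool:
    """Simplified sketch of the selection rules in thread_get_next_task."""
    if task_device and task_device != thread_device:
        return False  # pinned to a different device; leave it for that thread
    if not task_device and thread_device == "cpu" and alive_thread_count > 1:
        return False  # unpinned work goes to the GPU threads while any are alive
    return True
```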
task = queued_task @@ -266,19 +232,19 @@ def thread_get_next_task(): def thread_render(device): global current_state, current_state_error - from easydiffusion import model_manager, renderer + from easydiffusion import model_manager, runtime try: - renderer.init(device) + runtime.init(device) weak_thread_data[threading.current_thread()] = { - "device": renderer.context.device, - "device_name": renderer.context.device_name, + "device": runtime.context.device, + "device_name": runtime.context.device_name, "alive": True, } current_state = ServerStates.LoadingModel - model_manager.load_default_models(renderer.context) + model_manager.load_default_models(runtime.context) current_state = ServerStates.Online except Exception as e: @@ -290,8 +256,8 @@ def thread_render(device): session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]["alive"]: - log.info(f"Shutting down thread for device {renderer.context.device}") - model_manager.unload_all(renderer.context) + log.info(f"Shutting down thread for device {runtime.context.device}") + model_manager.unload_all(runtime.context) return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -311,62 +277,31 @@ def thread_render(device): task.response = {"status": "failed", "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}") + log.info(f"Session {task.session_id} starting task {task.id} on {runtime.context.device_name}") if not task.lock.acquire(blocking=False): raise Exception("Got locked task from queue.") try: + task.run() - def step_callback(): - global current_state_error - - task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.task_data.session_id, TASK_TTL) - - if ( - isinstance(current_state_error, SystemExit) - or isinstance(current_state_error, StopAsyncIteration) - or isinstance(task.error, StopAsyncIteration) - ): - renderer.context.stop_processing = True - if isinstance(current_state_error, StopAsyncIteration): - task.error = current_state_error - current_state_error = None - log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}") - - current_state = ServerStates.LoadingModel - model_manager.resolve_model_paths(task.task_data) - model_manager.reload_models_if_necessary(renderer.context, task.task_data) - model_manager.fail_if_models_did_not_load(renderer.context) - - current_state = ServerStates.Rendering - task.response = renderer.make_images( - task.render_request, - task.task_data, - task.buffer_queue, - task.temp_images, - step_callback, - ) # Before looping back to the generator, mark cache as still alive. 
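Note how the worker loop above now only calls `task.run()`: the render-specific logic has moved into `Task` subclasses, so the scheduler is type-agnostic. A minimal hypothetical subclass (not part of this diff) showing the contract — `run()` executes on the worker thread, streamed output goes through `buffer_queue`, and the final result is stored on `response`:

```python
import json

from easydiffusion.tasks import Task


class EchoTask(Task):
    "Hypothetical task that just echoes its payload back to the client."

    def __init__(self, session_id: str, payload: dict):
        super().__init__(session_id)
        self.payload = payload

    def run(self):
        # Runs on whichever worker thread picked this task up.
        self.response = {"status": "succeeded", "echo": self.payload}
        # Anything put on buffer_queue is streamed to /image/stream/<task id>
        self.buffer_queue.put(json.dumps(self.response))
```

Such a task would be queued with `task_manager.enqueue_task(...)`, exactly like `RenderTask` and `FilterTask`.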
- task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.task_data.session_id, TASK_TTL) + keep_task_alive(task) except Exception as e: task.error = str(e) task.response = {"status": "failed", "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) log.error(traceback.format_exc()) finally: - gc(renderer.context) + gc(runtime.context) task.lock.release() - task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.task_data.session_id, TASK_TTL) + + keep_task_alive(task) + if isinstance(task.error, StopAsyncIteration): - log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!") + log.info(f"Session {task.session_id} task {task.id} cancelled!") elif task.error is not None: - log.info(f"Session {task.task_data.session_id} task {id(task)} failed!") + log.info(f"Session {task.session_id} task {task.id} failed!") else: - log.info( - f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}." - ) + log.info(f"Session {task.session_id} task {task.id} completed by {runtime.context.device_name}.") current_state = ServerStates.Online @@ -548,28 +483,27 @@ def shutdown_event(): # Signal render thread to close on shutdown current_state_error = SystemExit("Application shutting down.") -def render(render_req: GenerateImageRequest, task_data: TaskData): +def enqueue_task(task: Task): current_thread_count = is_alive() if current_thread_count <= 0: # Render thread is dead raise ChildProcessError("Rendering thread has died.") # Alive, check if task in cache - session = get_cached_session(task_data.session_id, update_ttl=True) + session = get_cached_session(task.session_id, update_ttl=True) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) - if current_thread_count < len(pending_tasks): + if len(pending_tasks) > current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO: raise ConnectionRefusedError( - f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}." + f"Session {task.session_id} already has {len(pending_tasks)} pending tasks, with {current_thread_count} workers." ) - new_task = RenderTask(render_req, task_data) - if session.put(new_task, TASK_TTL): + if session.put(task, TASK_TTL): # Use twice the normal timeout for adding user requests. # Tries to force session.put to fail before tasks_queue.put would. 
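One behavioral change worth calling out: the old `render()` rejected a session as soon as its pending tasks outnumbered the render threads, whereas `enqueue_task` now tolerates up to `MAX_OVERLOAD_ALLOWED_RATIO` (2x) pending tasks per thread before answering HTTP 503. The arithmetic, as a small sketch:

```python
MAX_OVERLOAD_ALLOWED_RATIO = 2  # same constant as at the top of task_manager.py

def would_be_rejected(pending_task_count: int, render_thread_count: int) -> bool:
    # With 2 render threads a session may already hold up to 4 pending
    # tasks and still enqueue one more; the attempt after that gets a 503.
    return pending_task_count > render_thread_count * MAX_OVERLOAD_ALLOWED_RATIO
```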
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2): try: - tasks_queue.append(new_task) + tasks_queue.append(task) idle_event.set() - return new_task + return task finally: manager_lock.release() raise RuntimeError("Failed to add task to cache.") diff --git a/ui/easydiffusion/tasks/__init__.py b/ui/easydiffusion/tasks/__init__.py new file mode 100644 index 00000000..1d295da8 --- /dev/null +++ b/ui/easydiffusion/tasks/__init__.py @@ -0,0 +1,3 @@ +from .task import Task +from .render_images import RenderTask +from .filter_images import FilterTask diff --git a/ui/easydiffusion/tasks/filter_images.py b/ui/easydiffusion/tasks/filter_images.py new file mode 100644 index 00000000..c4e674d7 --- /dev/null +++ b/ui/easydiffusion/tasks/filter_images.py @@ -0,0 +1,110 @@ +import json +import pprint + +from sdkit.filter import apply_filters +from sdkit.models import load_model +from sdkit.utils import img_to_base64_str, log + +from easydiffusion import model_manager, runtime +from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData + +from .task import Task + + +class FilterTask(Task): + "For applying filters to input images" + + def __init__( + self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData + ): + super().__init__(session_id) + + self.request = req + self.models_data = models_data + self.output_format = output_format + + # convert to multi-filter format, if necessary + if isinstance(req.filter, str): + req.filter_params = {req.filter: req.filter_params} + req.filter = [req.filter] + + if not isinstance(req.image, list): + req.image = [req.image] + + def run(self): + "Runs the image filtering task on the assigned thread" + + context = runtime.context + + model_manager.resolve_model_paths(self.models_data) + model_manager.reload_models_if_necessary(context, self.models_data) + model_manager.fail_if_models_did_not_load(context) + + print_task_info(self.request, self.models_data, self.output_format) + + images = filter_images(context, self.request.image, self.request.filter, self.request.filter_params) + + output_format = self.output_format + images = [ + img_to_base64_str( + img, output_format.output_format, output_format.output_quality, output_format.output_lossless + ) + for img in images + ] + + res = FilterImageResponse(self.request, self.models_data, images=images) + res = res.json() + self.buffer_queue.put(json.dumps(res)) + log.info("Filter task completed") + + self.response = res + + +def filter_images(context, images, filters, filter_params={}): + filters = filters if isinstance(filters, list) else [filters] + + for filter_name in filters: + params = filter_params.get(filter_name, {}) + + previous_state = before_filter(context, filter_name, params) + + try: + images = apply_filters(context, filter_name, images, **params) + finally: + after_filter(context, filter_name, params, previous_state) + + return images + + +def before_filter(context, filter_name, filter_params): + if filter_name == "codeformer": + from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use + + default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] + prev_realesrgan_path = None + + upscale_faces = filter_params.get("upscale_faces", False) + if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: + prev_realesrgan_path = context.model_paths.get("realesrgan") + context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") + 
load_model(context, "realesrgan") + + return prev_realesrgan_path + + +def after_filter(context, filter_name, filter_params, previous_state): + if filter_name == "codeformer": + prev_realesrgan_path = previous_state + if prev_realesrgan_path: + context.model_paths["realesrgan"] = prev_realesrgan_path + load_model(context, "realesrgan") + + +def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData): + req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[") + models_data = pprint.pformat(models_data.dict()).replace("[", "\[") + output_format = pprint.pformat(output_format.dict()).replace("[", "\[") + + log.info(f"request: {req_str}") + log.info(f"models data: {models_data}") + log.info(f"output format: {output_format}") diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/tasks/render_images.py similarity index 54% rename from ui/easydiffusion/renderer.py rename to ui/easydiffusion/tasks/render_images.py index a57dfc6c..42bcc76a 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -3,70 +3,109 @@ import pprint import queue import time -from easydiffusion import device_manager -from easydiffusion.types import GenerateImageRequest +from easydiffusion import model_manager, runtime +from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData from easydiffusion.types import Image as ResponseImage -from easydiffusion.types import Response, TaskData, UserInitiatedStop -from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use +from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop from easydiffusion.utils import get_printable_request, log, save_images_to_disk -from sdkit import Context -from sdkit.filter import apply_filters from sdkit.generate import generate_images -from sdkit.models import load_model from sdkit.utils import ( diffusers_latent_samples_to_images, gc, img_to_base64_str, img_to_buffer, latent_samples_to_images, - get_device_usage, ) -context = Context() # thread-local -""" -runtime data (bound locally to this thread), for e.g. 
device, references to loaded models, optimization flags etc -""" +from .task import Task +from .filter_images import filter_images -def init(device): - """ - Initializes the fields that will be bound to this runtime's context, and sets the current torch device - """ - context.stop_processing = False - context.temp_images = {} - context.partial_x_samples = None - context.model_load_errors = {} - context.enable_codeformer = True +class RenderTask(Task): + "For image generation" - from easydiffusion import app + def __init__( + self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData + ): + super().__init__(task_data.session_id) - app_config = app.getConfig() - context.test_diffusers = ( - app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main" - ) + task_data.request_id = self.id + self.render_request: GenerateImageRequest = req # Initial Request + self.task_data: TaskData = task_data + self.models_data = models_data + self.output_format = output_format + self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) - log.info("Device usage during initialization:") - get_device_usage(device, log_info=True, process_usage_only=False) + def run(self): + "Runs the image generation task on the assigned thread" - device_manager.device_init(context, device) + from easydiffusion import task_manager + + context = runtime.context + + def step_callback(): + task_manager.keep_task_alive(self) + task_manager.current_state = task_manager.ServerStates.Rendering + + if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance( + self.error, StopAsyncIteration + ): + context.stop_processing = True + if isinstance(task_manager.current_state_error, StopAsyncIteration): + self.error = task_manager.current_state_error + task_manager.current_state_error = None + log.info(f"Session {self.session_id} sent cancel signal for task {self.id}") + + task_manager.current_state = task_manager.ServerStates.LoadingModel + model_manager.resolve_model_paths(self.models_data) + + models_to_force_reload = [] + if runtime.set_vram_optimizations(context) or self.has_clip_skip_changed(context): + models_to_force_reload.append("stable-diffusion") + + model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload) + model_manager.fail_if_models_did_not_load(context) + + task_manager.current_state = task_manager.ServerStates.Rendering + self.response = make_images( + context, + self.render_request, + self.task_data, + self.models_data, + self.output_format, + self.buffer_queue, + self.temp_images, + step_callback, + ) + + def has_clip_skip_changed(self, context): + if not context.test_diffusers: + return False + + model = context.models["stable-diffusion"] + new_clip_skip = self.models_data.model_params.get("stable-diffusion", {}).get("clip_skip", False) + return model["clip_skip"] != new_clip_skip def make_images( + context, req: GenerateImageRequest, task_data: TaskData, + models_data: ModelsData, + output_format: OutputFormatData, data_queue: queue.Queue, task_temp_images: list, step_callback, ): context.stop_processing = False - print_task_info(req, task_data) + print_task_info(req, task_data, models_data, output_format) - images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + images, seeds = make_images_internal( + context, req, task_data, models_data, output_format, data_queue, 
task_temp_images, step_callback + ) - res = Response( - req, - task_data, - images=construct_response(images, seeds, task_data, base_seed=req.seed), + res = GenerateImageResponse( + req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format) ) res = res.json() data_queue.put(json.dumps(res)) @@ -75,21 +114,32 @@ def make_images( return res -def print_task_info(req: GenerateImageRequest, task_data: TaskData): - req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[") +def print_task_info( + req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData +): + req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[") task_str = pprint.pformat(task_data.dict()).replace("[", "\[") + models_data = pprint.pformat(models_data.dict()).replace("[", "\[") + output_format = pprint.pformat(output_format.dict()).replace("[", "\[") + log.info(f"request: {req_str}") log.info(f"task data: {task_str}") + # log.info(f"models data: {models_data}") + log.info(f"output format: {output_format}") def make_images_internal( + context, req: GenerateImageRequest, task_data: TaskData, + models_data: ModelsData, + output_format: OutputFormatData, data_queue: queue.Queue, task_temp_images: list, step_callback, ): images, user_stopped = generate_images_internal( + context, req, task_data, data_queue, @@ -98,11 +148,14 @@ def make_images_internal( task_data.stream_image_progress, task_data.stream_image_progress_interval, ) + gc(context) - filtered_images = filter_images(req, task_data, images, user_stopped) + + filters, filter_params = task_data.filters, task_data.filter_params + filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images if task_data.save_to_disk_path is not None: - save_images_to_disk(images, filtered_images, req, task_data) + save_images_to_disk(images, filtered_images, req, task_data, output_format) seeds = [*range(req.seed, req.seed + len(images))] if task_data.show_only_filtered_image or filtered_images is images: @@ -112,6 +165,7 @@ def make_images_internal( def generate_images_internal( + context, req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, @@ -123,6 +177,7 @@ def generate_images_internal( context.temp_images.clear() callback = make_step_callback( + context, req, task_data, data_queue, @@ -155,65 +210,14 @@ def generate_images_internal( return images, user_stopped -def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped): - if user_stopped: - return images - - if task_data.block_nsfw: - images = apply_filters(context, "nsfw_checker", images) - - if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower(): - default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"] - prev_realesrgan_path = None - if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]: - prev_realesrgan_path = context.model_paths["realesrgan"] - context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan") - load_model(context, "realesrgan") - - try: - images = apply_filters( - context, - "codeformer", - images, - upscale_faces=task_data.codeformer_upscale_faces, - codeformer_fidelity=task_data.codeformer_fidelity, - ) - finally: - if prev_realesrgan_path: - context.model_paths["realesrgan"] = prev_realesrgan_path - load_model(context, "realesrgan") - elif 
task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): - images = apply_filters(context, "gfpgan", images) - - if task_data.use_upscale: - if "realesrgan" in task_data.use_upscale.lower(): - images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount) - elif task_data.use_upscale == "latent_upscaler": - images = apply_filters( - context, - "latent_upscaler", - images, - scale=task_data.upscale_amount, - latent_upscaler_options={ - "prompt": req.prompt, - "negative_prompt": req.negative_prompt, - "seed": req.seed, - "num_inference_steps": task_data.latent_upscaler_steps, - "guidance_scale": 0, - }, - ) - - return images - - -def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int): +def construct_response(images: list, seeds: list, output_format: OutputFormatData): return [ ResponseImage( data=img_to_base64_str( img, - task_data.output_format, - task_data.output_quality, - task_data.output_lossless, + output_format.output_format, + output_format.output_quality, + output_format.output_lossless, ), seed=seed, ) @@ -222,6 +226,7 @@ def construct_response(images: list, seeds: list, task_data: TaskData, base_seed def make_step_callback( + context, req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, @@ -242,7 +247,7 @@ def make_step_callback( images = latent_samples_to_images(context, x_samples) if task_data.block_nsfw: - images = apply_filters(context, "nsfw_checker", images) + images = filter_images(context, images, "nsfw_checker") for i, img in enumerate(images): buf = img_to_buffer(img, output_format="JPEG") diff --git a/ui/easydiffusion/tasks/task.py b/ui/easydiffusion/tasks/task.py new file mode 100644 index 00000000..4454efe6 --- /dev/null +++ b/ui/easydiffusion/tasks/task.py @@ -0,0 +1,47 @@ +from threading import Lock +from queue import Queue, Empty as EmptyQueueException +from typing import Any + + +class Task: + "Task with output queue and completion lock" + + def __init__(self, session_id): + self.id = id(self) + self.session_id = session_id + self.render_device = None # Select the task affinity. (Not used to change active devices). 
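The fields initialized here (continued just below) feed the generic `status` property, whose checks are ordered by precedence: running, stopped, error, buffer, completed, pending. A short usage sketch of those transitions, grounded in the property as written below:

```python
from easydiffusion.tasks import Task

t = Task("some-session")
assert t.status == "pending"  # no lock, no error, empty buffer, no response

t.lock.acquire()  # thread_render holds this lock while the task runs
assert t.status == "running"
t.lock.release()

t.buffer_queue.put('{"step": 1}')  # streamed-but-unread output
assert t.status == "buffer"
t.buffer_queue.get()
t.buffer_queue.task_done()

t.response = {"status": "succeeded"}
assert t.status == "completed"

t.error = StopAsyncIteration()  # the cancel signal used by task_manager
assert t.status == "stopped"  # checked before the generic "error" state
```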
+ self.error: Exception = None + self.lock: Lock = Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: Queue = Queue() # Queue of JSON string segments + self.response: Any = None # Copy of the last response + + async def read_buffer_generator(self): + try: + while not self.buffer_queue.empty(): + res = self.buffer_queue.get(block=False) + self.buffer_queue.task_done() + yield res + except EmptyQueueException as e: + yield + + @property + def status(self): + if self.lock.locked(): + return "running" + if isinstance(self.error, StopAsyncIteration): + return "stopped" + if self.error: + return "error" + if not self.buffer_queue.empty(): + return "buffer" + if self.response: + return "completed" + return "pending" + + @property + def is_pending(self): + return bool(not self.response and not self.error) + + def run(self): + "Override this to implement the task's behavior" + pass diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index a9e49a24..b5f6b21a 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -1,4 +1,4 @@ -from typing import Any, List, Union +from typing import Any, List, Dict, Union from pydantic import BaseModel @@ -17,6 +17,8 @@ class GenerateImageRequest(BaseModel): init_image: Any = None init_image_mask: Any = None + control_image: Any = None + control_alpha: Union[float, List[float]] = None prompt_strength: float = 0.8 preserve_init_image_color_profile = False @@ -26,6 +28,35 @@ class GenerateImageRequest(BaseModel): tiling: str = "none" # "none", "x", "y", "xy" +class FilterImageRequest(BaseModel): + image: Any = None + filter: Union[str, List[str]] = None + filter_params: dict = {} + + +class ModelsData(BaseModel): + """ + Contains the information related to the models involved in a request. + + - To load a model: set the relative path(s) to the model in `model_paths`. No effect if already loaded. + - To unload a model: set the model to `None` in `model_paths`. No effect if already unloaded. + + Models that aren't present in `model_paths` will not be changed.
+ """ + + model_paths: Dict[str, Union[str, None, List[str]]] = None + "model_type to string path, or list of string paths" + + model_params: Dict[str, Dict[str, Any]] = {} + "model_type to dict of parameters" + + +class OutputFormatData(BaseModel): + output_format: str = "jpeg" # or "png" or "webp" + output_quality: int = 75 + output_lossless: bool = False + + class TaskData(BaseModel): request_id: str = None session_id: str = "session" @@ -40,12 +71,12 @@ class TaskData(BaseModel): use_vae_model: Union[str, List[str]] = None use_hypernetwork_model: Union[str, List[str]] = None use_lora_model: Union[str, List[str]] = None + use_controlnet_model: Union[str, List[str]] = None + filters: List[str] = [] + filter_params: Dict[str, Dict[str, Any]] = {} show_only_filtered_image: bool = False block_nsfw: bool = False - output_format: str = "jpeg" # or "png" or "webp" - output_quality: int = 75 - output_lossless: bool = False metadata_output_format: str = "txt" # or "json" stream_image_progress: bool = False stream_image_progress_interval: int = 5 @@ -80,24 +111,38 @@ class Image: } -class Response: +class GenerateImageResponse: render_request: GenerateImageRequest task_data: TaskData + models_data: ModelsData images: list - def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list): + def __init__( + self, + render_request: GenerateImageRequest, + task_data: TaskData, + models_data: ModelsData, + output_format: OutputFormatData, + images: list, + ): self.render_request = render_request self.task_data = task_data + self.models_data = models_data + self.output_format = output_format self.images = images def json(self): del self.render_request.init_image del self.render_request.init_image_mask + task_data = self.task_data.dict() + task_data.update(self.output_format.dict()) + res = { "status": "succeeded", "render_request": self.render_request.dict(), - "task_data": self.task_data.dict(), + "task_data": task_data, + # "models_data": self.models_data.dict(), # haven't migrated the UI to the new format (yet) "output": [], } @@ -107,5 +152,102 @@ class Response: return res +class FilterImageResponse: + request: FilterImageRequest + models_data: ModelsData + images: list + + def __init__(self, request: FilterImageRequest, models_data: ModelsData, images: list): + self.request = request + self.models_data = models_data + self.images = images + + def json(self): + del self.request.image + + res = { + "status": "succeeded", + "request": self.request.dict(), + "models_data": self.models_data.dict(), + "output": [], + } + + for image in self.images: + res["output"].append(image) + + return res + + class UserInitiatedStop(Exception): pass + + +def convert_legacy_render_req_to_new(old_req: dict): + new_req = dict(old_req) + + # new keys + model_paths = new_req["model_paths"] = {} + model_params = new_req["model_params"] = {} + filters = new_req["filters"] = [] + filter_params = new_req["filter_params"] = {} + + # move the model info + model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model") + model_paths["vae"] = old_req.get("use_vae_model") + model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model") + model_paths["lora"] = old_req.get("use_lora_model") + model_paths["controlnet"] = old_req.get("use_controlnet_model") + + model_paths["gfpgan"] = old_req.get("use_face_correction", "") + model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None + + model_paths["codeformer"] = 
old_req.get("use_face_correction", "") + model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None + + model_paths["realesrgan"] = old_req.get("use_upscale", "") + model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None + + model_paths["latent_upscaler"] = old_req.get("use_upscale", "") + model_paths["latent_upscaler"] = ( + model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None + ) + + if old_req.get("block_nsfw"): + model_paths["nsfw_checker"] = "nsfw_checker" + + # move the model params + if model_paths["stable-diffusion"]: + model_params["stable-diffusion"] = {"clip_skip": bool(old_req["clip_skip"])} + + # move the filter params + if model_paths["realesrgan"]: + filter_params["realesrgan"] = {"scale": int(old_req["upscale_amount"])} + if model_paths["latent_upscaler"]: + filter_params["latent_upscaler"] = { + "prompt": old_req["prompt"], + "negative_prompt": old_req.get("negative_prompt"), + "seed": int(old_req.get("seed", 42)), + "num_inference_steps": int(old_req.get("latent_upscaler_steps", 10)), + "guidance_scale": 0, + } + if model_paths["codeformer"]: + filter_params["codeformer"] = { + "upscale_faces": bool(old_req["codeformer_upscale_faces"]), + "codeformer_fidelity": float(old_req["codeformer_fidelity"]), + } + + # set the filters + if old_req.get("block_nsfw"): + filters.append("nsfw_checker") + + if model_paths["codeformer"]: + filters.append("codeformer") + elif model_paths["gfpgan"]: + filters.append("gfpgan") + + if model_paths["realesrgan"]: + filters.append("realesrgan") + elif model_paths["latent_upscaler"]: + filters.append("latent_upscaler") + + return new_req diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index f27e84de..49743554 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -7,7 +7,7 @@ from datetime import datetime from functools import reduce from easydiffusion import app -from easydiffusion.types import GenerateImageRequest, TaskData +from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData from numpy import base_repr from sdkit.utils import save_dicts, save_images @@ -114,12 +114,14 @@ def format_file_name( return filename_regex.sub("_", format) -def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): +def save_images_to_disk( + images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData +): now = time.time() app_config = app.getConfig() folder_format = app_config.get("folder_format", "$id") save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) - metadata_entries = get_metadata_entries_for_request(req, task_data) + metadata_entries = get_metadata_entries_for_request(req, task_data, output_format) file_number = calculate_img_number(save_dir_path, task_data) make_filename = make_filename_callback( app_config.get("filename_format", "$p_$tsb64"), @@ -134,9 +136,9 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR filtered_images, save_dir_path, file_name=make_filename, - output_format=task_data.output_format, - output_quality=task_data.output_quality, - output_lossless=task_data.output_lossless, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + 
output_lossless=output_format.output_lossless, ) if task_data.metadata_output_format: for metadata_output_format in task_data.metadata_output_format.split(","): @@ -146,7 +148,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR save_dir_path, file_name=make_filename, output_format=metadata_output_format, - file_format=task_data.output_format, + file_format=output_format.output_format, ) else: make_filter_filename = make_filename_callback( @@ -162,17 +164,17 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR images, save_dir_path, file_name=make_filename, - output_format=task_data.output_format, - output_quality=task_data.output_quality, - output_lossless=task_data.output_lossless, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, ) save_images( filtered_images, save_dir_path, file_name=make_filter_filename, - output_format=task_data.output_format, - output_quality=task_data.output_quality, - output_lossless=task_data.output_lossless, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, ) if task_data.metadata_output_format: for metadata_output_format in task_data.metadata_output_format.split(","): @@ -181,20 +183,21 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR metadata_entries, save_dir_path, file_name=make_filter_filename, - output_format=task_data.metadata_output_format, - file_format=task_data.output_format, + output_format=metadata_output_format, + file_format=output_format.output_format, ) -def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData): - metadata = get_printable_request(req, task_data) +def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData): + metadata = get_printable_request(req, task_data, output_format) # if text, format it in the text format expected by the UI is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",") if is_txt_format: + def format_value(value): if isinstance(value, list): - return ", ".join([ str(it) for it in value ]) + return ", ".join([str(it) for it in value]) return value metadata = { @@ -208,9 +211,10 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD return entries -def get_printable_request(req: GenerateImageRequest, task_data: TaskData): +def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData): req_metadata = req.dict() task_data_metadata = task_data.dict() + task_data_metadata.update(output_format.dict()) app_config = app.getConfig() using_diffusers = app_config.get("test_diffusers", False) @@ -224,6 +228,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData): metadata[key] = task_data_metadata[key] elif key == "use_embedding_models" and using_diffusers: embeddings_extensions = {".pt", ".bin", ".safetensors"} + def scan_directory(directory_path: str): used_embeddings = [] for entry in os.scandir(directory_path): @@ -232,15 +237,18 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData): if entry_extension not in embeddings_extensions: continue - embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])") + 
embedding_name_regex = regex.compile( + r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])" + ) if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt): used_embeddings.append(entry.path) elif entry.is_dir(): used_embeddings.extend(scan_directory(entry.path)) return used_embeddings + used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings")) metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None - + # Clean up the metadata if req.init_image is None and "prompt_strength" in metadata: del metadata["prompt_strength"] @@ -254,7 +262,9 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData): del metadata["latent_upscaler_steps"] if not using_diffusers: - for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata): + for key in ( + x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata + ): del metadata[key] return metadata diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js index 39e06ed7..c584a609 100644 --- a/ui/media/js/engine.js +++ b/ui/media/js/engine.js @@ -1047,17 +1047,22 @@ } } class FilterTask extends Task { - constructor(options = {}) {} + constructor(options = {}) { + super(options) + } /** Send current task to server. * @param {*} [timeout=-1] Optional timeout value in ms * @returns the response from the render request. * @memberof Task */ async post(timeout = -1) { - let jsonResponse = await super.post("/filter", timeout) + let res = await super.post("/filter", timeout) //this._setId(jsonResponse.task) this._setStatus(TaskStatus.waiting) + + return res } + checkReqBody() {} enqueue(progressCallback) { return Task.enqueueNew(this, FilterTask, progressCallback) } @@ -1068,6 +1073,20 @@ if (this.isStopped) { return } + + this._setStatus(TaskStatus.pending) + progressCallback?.call(this, { reqBody: this._reqBody }) + Object.freeze(this._reqBody) + + // Post task request to backend + let renderRes = undefined + try { + renderRes = yield this.post() + yield progressCallback?.call(this, { renderResponse: renderRes }) + } catch (e) { + yield progressCallback?.call(this, { detail: e.message }) + throw e + } } static start(task, progressCallback) { if (typeof task !== "object") {
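Finally, a worked example of `convert_legacy_render_req_to_new` from `types.py` above, since that function is the bridge that keeps the unmigrated JS client working. The field values are hypothetical; the resulting keys follow directly from the function body:

```python
from easydiffusion.types import convert_legacy_render_req_to_new

# A legacy-style request, as the current UI still sends it:
old_req = {
    "session_id": "s1",
    "prompt": "an astronaut riding a horse",
    "use_stable_diffusion_model": "sd-v1-4",
    "use_face_correction": "GFPGANv1.4",
    "use_upscale": "RealESRGAN_x4plus",
    "upscale_amount": 4,
    "clip_skip": False,
    "block_nsfw": True,
}

new_req = convert_legacy_render_req_to_new(old_req)

# The converted request keeps every legacy key and adds the new-style ones:
assert new_req["model_paths"]["stable-diffusion"] == "sd-v1-4"
assert new_req["model_paths"]["gfpgan"] == "GFPGANv1.4"  # name contains "gfpgan"
assert new_req["model_paths"]["codeformer"] is None  # no "codeformer" in the name
assert new_req["model_paths"]["nsfw_checker"] == "nsfw_checker"
assert new_req["filters"] == ["nsfw_checker", "gfpgan", "realesrgan"]
assert new_req["filter_params"]["realesrgan"] == {"scale": 4}
assert new_req["model_params"]["stable-diffusion"] == {"clip_skip": False}
```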