Mega refactor of the task processing and rendering logic; Split filter into a separate task, and add support for running filter tasks individually; Change the format for sending model and filter data from the API, but maintain backwards compatibility for now with the old API

cmdr2 2023-07-28 18:57:28 +05:30
parent b93c624efa
commit e61549e0cd
11 changed files with 626 additions and 300 deletions
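For readers comparing the two request formats: below is a minimal sketch of how a legacy render request maps onto the new model/filter fields, based on convert_legacy_render_req_to_new() (in easydiffusion.types, per the diff below). The model names and values are illustrative placeholders, not taken from this commit.

# Hypothetical legacy request (old API shape; model names are placeholders):
legacy_req = {
    "use_stable_diffusion_model": "sd-v1-4",
    "use_face_correction": "GFPGANv1.4",
    "use_upscale": "RealESRGAN_x4plus",
    "upscale_amount": 4,
    "clip_skip": False,
}

# convert_legacy_render_req_to_new() adds the new-format keys alongside the old ones:
#   "model_paths":   {"stable-diffusion": "sd-v1-4", "gfpgan": "GFPGANv1.4",
#                     "realesrgan": "RealESRGAN_x4plus", ...}
#   "model_params":  {"stable-diffusion": {"clip_skip": False}}
#   "filters":       ["gfpgan", "realesrgan"]
#   "filter_params": {"realesrgan": {"scale": 4}}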

View File

@@ -5,7 +5,7 @@ import traceback
 from typing import Union
 
 from easydiffusion import app
-from easydiffusion.types import TaskData
+from easydiffusion.types import ModelsData
 from easydiffusion.utils import log
 from sdkit import Context
 from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
@@ -57,7 +57,9 @@ def init():
 def load_default_models(context: Context):
-    set_vram_optimizations(context)
+    from easydiffusion import runtime
+
+    runtime.set_vram_optimizations(context)
 
     config = app.getConfig()
     context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
@@ -138,43 +140,32 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
     raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
 
-def reload_models_if_necessary(context: Context, task_data: TaskData):
-    face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else ""
-    upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else ""
-
-    model_paths_in_req = {
-        "stable-diffusion": task_data.use_stable_diffusion_model,
-        "vae": task_data.use_vae_model,
-        "hypernetwork": task_data.use_hypernetwork_model,
-        "codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None,
-        "gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None,
-        "realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None,
-        "latent_upscaler": True if "latent_upscaler" in upscale_lower else None,
-        "nsfw_checker": True if task_data.block_nsfw else None,
-        "lora": task_data.use_lora_model,
-    }
+def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
     models_to_reload = {
         model_type: path
-        for model_type, path in model_paths_in_req.items()
+        for model_type, path in models_data.model_paths.items()
         if context.model_paths.get(model_type) != path
     }
 
-    if task_data.codeformer_upscale_faces:
+    if models_data.model_paths.get("codeformer"):
         if "realesrgan" not in models_to_reload and "realesrgan" not in context.models:
             default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
             models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
         elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None:
             del models_to_reload["realesrgan"]  # don't unload realesrgan
 
-    if set_vram_optimizations(context) or set_clip_skip(context, task_data):  # reload SD
-        models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
+    for model_type in models_to_force_reload:
+        if model_type not in models_data.model_paths:
+            continue
+        models_to_reload[model_type] = models_data.model_paths[model_type]
 
     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
         action_fn = unload_model if context.model_paths[model_type] is None else load_model
+        extra_params = models_data.model_params.get(model_type, {})
         try:
-            action_fn(context, model_type, scan_model=False)  # we've scanned them already
+            action_fn(context, model_type, scan_model=False, **extra_params)  # we've scanned them already
             if model_type in context.model_load_errors:
                 del context.model_load_errors[model_type]
         except Exception as e:
@@ -183,24 +174,15 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
             context.model_load_errors[model_type] = str(e)  # storing the entire Exception can lead to memory leaks
 
-def resolve_model_paths(task_data: TaskData):
-    task_data.use_stable_diffusion_model = resolve_model_to_use(
-        task_data.use_stable_diffusion_model, model_type="stable-diffusion"
-    )
-    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
-    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
-    task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora")
-
-    if task_data.use_face_correction:
-        if "gfpgan" in task_data.use_face_correction.lower():
-            model_type = "gfpgan"
-        elif "codeformer" in task_data.use_face_correction.lower():
-            model_type = "codeformer"
+def resolve_model_paths(models_data: ModelsData):
+    model_paths = models_data.model_paths
+    for model_type in model_paths:
+        if model_type in ("latent_upscaler", "nsfw_checker"):  # doesn't use model paths
+            continue
+        if model_type == "codeformer":
             download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
-        task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type)
-
-    if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
-        task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
+
+        model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
 
 def fail_if_models_did_not_load(context: Context):
@@ -235,17 +217,6 @@ def download_if_necessary(model_type: str, file_name: str, model_id: str):
     download_model(model_type, model_id, download_base_dir=app.MODELS_DIR)
 
-def set_vram_optimizations(context: Context):
-    config = app.getConfig()
-    vram_usage_level = config.get("vram_usage_level", "balanced")
-
-    if vram_usage_level != context.vram_usage_level:
-        context.vram_usage_level = vram_usage_level
-        return True
-
-    return False
-
 
 def migrate_legacy_model_location():
     'Move the models inside the legacy "stable-diffusion" folder, to their respective folders'
@@ -266,16 +237,6 @@ def any_model_exists(model_type: str) -> bool:
     return False
 
-def set_clip_skip(context: Context, task_data: TaskData):
-    clip_skip = task_data.clip_skip
-
-    if clip_skip != context.clip_skip:
-        context.clip_skip = clip_skip
-        return True
-
-    return False
-
 
 def make_model_folders():
     for model_type in KNOWN_MODEL_TYPES:
         model_dir_path = os.path.join(app.MODELS_DIR, model_type)

View File

@@ -0,0 +1,53 @@
+"""
+A runtime that runs on a specific device (in a thread).
+
+It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context.
+
+This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function.
+"""
+
+from easydiffusion import device_manager
+from easydiffusion.utils import log
+
+from sdkit import Context
+from sdkit.utils import get_device_usage
+
+context = Context()  # thread-local
+"""
+runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
+"""
+
+
+def init(device):
+    """
+    Initializes the fields that will be bound to this runtime's context, and sets the current torch device
+    """
+    context.stop_processing = False
+    context.temp_images = {}
+    context.partial_x_samples = None
+    context.model_load_errors = {}
+    context.enable_codeformer = True
+
+    from easydiffusion import app
+
+    app_config = app.getConfig()
+    context.test_diffusers = (
+        app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
+    )
+
+    log.info("Device usage during initialization:")
+    get_device_usage(device, log_info=True, process_usage_only=False)
+
+    device_manager.device_init(context, device)
+
+
+def set_vram_optimizations(context: Context):
+    from easydiffusion import app
+
+    config = app.getConfig()
+    vram_usage_level = config.get("vram_usage_level", "balanced")
+
+    if vram_usage_level != context.vram_usage_level:
+        context.vram_usage_level = vram_usage_level
+        return True
+
+    return False
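
A quick sketch of how a worker thread is expected to use this new runtime module, mirroring thread_render() in the task manager below; the device id string is illustrative:

from easydiffusion import model_manager, runtime

runtime.init("cuda:0")  # binds an sdkit.Context to this thread and device
model_manager.load_default_models(runtime.context)
# tasks that later run on this thread read the same thread-local runtime.context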

View File

@@ -9,7 +9,16 @@ import traceback
 from typing import List, Union
 
 from easydiffusion import app, model_manager, task_manager
-from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
+from easydiffusion.tasks import RenderTask, FilterTask
+from easydiffusion.types import (
+    GenerateImageRequest,
+    FilterImageRequest,
+    MergeRequest,
+    TaskData,
+    ModelsData,
+    OutputFormatData,
+    convert_legacy_render_req_to_new,
+)
 from easydiffusion.utils import log
 from fastapi import FastAPI, HTTPException
 from fastapi.staticfiles import StaticFiles
@@ -97,6 +106,10 @@ def init():
     def render(req: dict):
         return render_internal(req)
 
+    @server_api.post("/filter")
+    def render(req: dict):
+        return filter_internal(req)
+
     @server_api.post("/model/merge")
     def model_merge(req: dict):
         print(req)
@@ -228,9 +241,13 @@ def ping_internal(session_id: str = None):
 def render_internal(req: dict):
     try:
+        req = convert_legacy_render_req_to_new(req)
+
         # separate out the request data into rendering and task-specific data
         render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
         task_data: TaskData = TaskData.parse_obj(req)
+        models_data: ModelsData = ModelsData.parse_obj(req)
+        output_format: OutputFormatData = OutputFormatData.parse_obj(req)
 
         # Overwrite user specified save path
         config = app.getConfig()
@@ -240,28 +257,53 @@ def render_internal(req: dict):
         render_req.init_image_mask = req.get("mask")  # hack: will rename this in the HTTP API in a future revision
 
         app.save_to_config(
-            task_data.use_stable_diffusion_model,
-            task_data.use_vae_model,
-            task_data.use_hypernetwork_model,
+            models_data.model_paths.get("stable-diffusion"),
+            models_data.model_paths.get("vae"),
+            models_data.model_paths.get("hypernetwork"),
             task_data.vram_usage_level,
         )
 
         # enqueue the task
-        new_task = task_manager.render(render_req, task_data)
+        task = RenderTask(render_req, task_data, models_data, output_format)
+        return enqueue_task(task)
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        log.error(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+def filter_internal(req: dict):
+    try:
+        session_id = req.get("session_id", "session")
+        filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req)
+        models_data: ModelsData = ModelsData.parse_obj(req)
+        output_format: OutputFormatData = OutputFormatData.parse_obj(req)
+
+        # enqueue the task
+        task = FilterTask(filter_req, session_id, models_data, output_format)
+        return enqueue_task(task)
+    except HTTPException as e:
+        raise e
+    except Exception as e:
+        log.error(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+def enqueue_task(task):
+    try:
+        task_manager.enqueue_task(task)
         response = {
             "status": str(task_manager.current_state),
             "queue": len(task_manager.tasks_queue),
-            "stream": f"/image/stream/{id(new_task)}",
-            "task": id(new_task),
+            "stream": f"/image/stream/{task.id}",
+            "task": task.id,
         }
         return JSONResponse(response, headers=NOCACHE_HEADERS)
     except ChildProcessError as e:  # Render thread is dead
         raise HTTPException(status_code=500, detail=f"Rendering thread has died.")  # HTTP500 Internal Server Error
     except ConnectionRefusedError as e:  # Unstarted task pending limit reached, deny queueing too many.
         raise HTTPException(status_code=503, detail=str(e))  # HTTP503 Service Unavailable
+    except Exception as e:
+        log.error(traceback.format_exc())
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 def model_merge_internal(req: dict):

View File

@@ -17,7 +17,7 @@ from typing import Any, Hashable
 import torch
 
 from easydiffusion import device_manager
-from easydiffusion.types import GenerateImageRequest, TaskData
+from easydiffusion.tasks import Task
 from easydiffusion.utils import log
 from sdkit.utils import gc
@@ -27,6 +27,7 @@ LOCK_TIMEOUT = 15  # Maximum locking time in seconds before failing a task.
 # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
 
 DEVICE_START_TIMEOUT = 60  # seconds - Maximum time to wait for a render device to init.
+MAX_OVERLOAD_ALLOWED_RATIO = 2  # i.e. 2x pending tasks compared to the number of render threads
 
 class SymbolClass(type):  # Print nicely formatted Symbol names.
@@ -58,46 +59,6 @@ class ServerStates:
         pass
 
-class RenderTask:  # Task with output queue and completion lock.
-    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
-        task_data.request_id = id(self)
-        self.render_request: GenerateImageRequest = req  # Initial Request
-        self.task_data: TaskData = task_data
-        self.response: Any = None  # Copy of the last reponse
-        self.render_device = None  # Select the task affinity. (Not used to change active devices).
-        self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
-        self.error: Exception = None
-        self.lock: threading.Lock = threading.Lock()  # Locks at task start and unlocks when task is completed
-        self.buffer_queue: queue.Queue = queue.Queue()  # Queue of JSON string segments
-
-    async def read_buffer_generator(self):
-        try:
-            while not self.buffer_queue.empty():
-                res = self.buffer_queue.get(block=False)
-                self.buffer_queue.task_done()
-                yield res
-        except queue.Empty as e:
-            yield
-
-    @property
-    def status(self):
-        if self.lock.locked():
-            return "running"
-        if isinstance(self.error, StopAsyncIteration):
-            return "stopped"
-        if self.error:
-            return "error"
-        if not self.buffer_queue.empty():
-            return "buffer"
-        if self.response:
-            return "completed"
-        return "pending"
-
-    @property
-    def is_pending(self):
-        return bool(not self.response and not self.error)
-
 
 # Temporary cache to allow to query tasks results for a short time after they are completed.
 class DataCache:
     def __init__(self):
@@ -123,8 +84,8 @@ class DataCache:
                 # Remove Items
                 for key in to_delete:
                     (_, val) = self._base[key]
-                    if isinstance(val, RenderTask):
-                        log.debug(f"RenderTask {key} expired. Data removed.")
+                    if isinstance(val, Task):
+                        log.debug(f"Task {key} expired. Data removed.")
                     elif isinstance(val, SessionState):
                         log.debug(f"Session {key} expired. Data removed.")
                     else:
@@ -220,8 +181,8 @@ class SessionState:
             tasks.append(task)
         return tasks
 
-    def put(self, task, ttl=TASK_TTL):
-        task_id = id(task)
+    def put(self, task: Task, ttl=TASK_TTL):
+        task_id = task.id
         self._tasks_ids.append(task_id)
         if not task_cache.put(task_id, task, ttl):
             return False
@@ -230,11 +191,16 @@ class SessionState:
         return True
 
+def keep_task_alive(task: Task):
+    task_cache.keep(task.id, TASK_TTL)
+    session_cache.keep(task.session_id, TASK_TTL)
+
+
 def thread_get_next_task():
-    from easydiffusion import renderer
+    from easydiffusion import runtime
 
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
-        log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
+        log.warn(f"Render thread on device: {runtime.context.device} failed to acquire manager lock.")
         return None
     if len(tasks_queue) <= 0:
         manager_lock.release()
@@ -242,7 +208,7 @@ def thread_get_next_task():
     task = None
     try:  # Select a render task.
         for queued_task in tasks_queue:
-            if queued_task.render_device and renderer.context.device != queued_task.render_device:
+            if queued_task.render_device and runtime.context.device != queued_task.render_device:
                 # Is asking for a specific render device.
                 if is_alive(queued_task.render_device) > 0:
                     continue  # requested device alive, skip current one.
@@ -251,7 +217,7 @@ def thread_get_next_task():
                     queued_task.error = Exception(queued_task.render_device + " is not currently active.")
                     task = queued_task
                     break
-            if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
+            if not queued_task.render_device and runtime.context.device == "cpu" and is_alive() > 1:
                 # not asking for any specific devices, cpu want to grab task but other render devices are alive.
                 continue  # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
             task = queued_task
@@ -266,19 +232,19 @@ def thread_get_next_task():
 def thread_render(device):
     global current_state, current_state_error
 
-    from easydiffusion import model_manager, renderer
+    from easydiffusion import model_manager, runtime
 
     try:
-        renderer.init(device)
+        runtime.init(device)
 
         weak_thread_data[threading.current_thread()] = {
-            "device": renderer.context.device,
-            "device_name": renderer.context.device_name,
+            "device": runtime.context.device,
+            "device_name": runtime.context.device_name,
             "alive": True,
         }
 
         current_state = ServerStates.LoadingModel
-        model_manager.load_default_models(renderer.context)
+        model_manager.load_default_models(runtime.context)
 
         current_state = ServerStates.Online
     except Exception as e:
@@ -290,8 +256,8 @@ def thread_render(device):
         session_cache.clean()
         task_cache.clean()
         if not weak_thread_data[threading.current_thread()]["alive"]:
-            log.info(f"Shutting down thread for device {renderer.context.device}")
-            model_manager.unload_all(renderer.context)
+            log.info(f"Shutting down thread for device {runtime.context.device}")
+            model_manager.unload_all(runtime.context)
             return
         if isinstance(current_state_error, SystemExit):
             current_state = ServerStates.Unavailable
@@ -311,62 +277,31 @@ def thread_render(device):
             task.response = {"status": "failed", "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
+        log.info(f"Session {task.session_id} starting task {task.id} on {runtime.context.device_name}")
         if not task.lock.acquire(blocking=False):
             raise Exception("Got locked task from queue.")
         try:
-            def step_callback():
-                global current_state_error
-
-                task_cache.keep(id(task), TASK_TTL)
-                session_cache.keep(task.task_data.session_id, TASK_TTL)
-
-                if (
-                    isinstance(current_state_error, SystemExit)
-                    or isinstance(current_state_error, StopAsyncIteration)
-                    or isinstance(task.error, StopAsyncIteration)
-                ):
-                    renderer.context.stop_processing = True
-                    if isinstance(current_state_error, StopAsyncIteration):
-                        task.error = current_state_error
-                        current_state_error = None
-                        log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
-
-            current_state = ServerStates.LoadingModel
-            model_manager.resolve_model_paths(task.task_data)
-            model_manager.reload_models_if_necessary(renderer.context, task.task_data)
-            model_manager.fail_if_models_did_not_load(renderer.context)
-
-            current_state = ServerStates.Rendering
-            task.response = renderer.make_images(
-                task.render_request,
-                task.task_data,
-                task.buffer_queue,
-                task.temp_images,
-                step_callback,
-            )
+            task.run()
 
             # Before looping back to the generator, mark cache as still alive.
-            task_cache.keep(id(task), TASK_TTL)
-            session_cache.keep(task.task_data.session_id, TASK_TTL)
+            keep_task_alive(task)
         except Exception as e:
             task.error = str(e)
             task.response = {"status": "failed", "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             log.error(traceback.format_exc())
         finally:
-            gc(renderer.context)
+            gc(runtime.context)
             task.lock.release()
 
-        task_cache.keep(id(task), TASK_TTL)
-        session_cache.keep(task.task_data.session_id, TASK_TTL)
+        keep_task_alive(task)
+
         if isinstance(task.error, StopAsyncIteration):
-            log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
+            log.info(f"Session {task.session_id} task {task.id} cancelled!")
         elif task.error is not None:
-            log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
+            log.info(f"Session {task.session_id} task {task.id} failed!")
         else:
-            log.info(
-                f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
-            )
+            log.info(f"Session {task.session_id} task {task.id} completed by {runtime.context.device_name}.")
         current_state = ServerStates.Online
@@ -548,28 +483,27 @@ def shutdown_event():  # Signal render thread to close on shutdown
     current_state_error = SystemExit("Application shutting down.")
 
-def render(render_req: GenerateImageRequest, task_data: TaskData):
+def enqueue_task(task: Task):
     current_thread_count = is_alive()
     if current_thread_count <= 0:  # Render thread is dead
         raise ChildProcessError("Rendering thread has died.")
 
     # Alive, check if task in cache
-    session = get_cached_session(task_data.session_id, update_ttl=True)
+    session = get_cached_session(task.session_id, update_ttl=True)
     pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
-    if current_thread_count < len(pending_tasks):
+    if len(pending_tasks) > current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO:
         raise ConnectionRefusedError(
-            f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
+            f"Session {task.session_id} already has {len(pending_tasks)} pending tasks, with {current_thread_count} workers."
        )
 
-    new_task = RenderTask(render_req, task_data)
-    if session.put(new_task, TASK_TTL):
+    if session.put(task, TASK_TTL):
         # Use twice the normal timeout for adding user requests.
         # Tries to force session.put to fail before tasks_queue.put would.
         if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
             try:
-                tasks_queue.append(new_task)
+                tasks_queue.append(task)
                 idle_event.set()
-                return new_task
+                return task
             finally:
                 manager_lock.release()
     raise RuntimeError("Failed to add task to cache.")
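
The new MAX_OVERLOAD_ALLOWED_RATIO check loosens the old one-pending-task-per-thread rule; a small worked example (numbers illustrative):

current_thread_count = 2  # live render threads
MAX_OVERLOAD_ALLOWED_RATIO = 2
limit = current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO  # 4 pending tasks per session
# a 5th pending task in the same session raises ConnectionRefusedError,
# which the server maps to HTTP 503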

View File

@@ -0,0 +1,3 @@
+from .task import Task
+from .render_images import RenderTask
+from .filter_images import FilterTask

View File

@@ -0,0 +1,110 @@
+import json
+import pprint
+
+from sdkit.filter import apply_filters
+from sdkit.models import load_model
+from sdkit.utils import img_to_base64_str, log
+
+from easydiffusion import model_manager, runtime
+from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData
+
+from .task import Task
+
+
+class FilterTask(Task):
+    "For applying filters to input images"
+
+    def __init__(
+        self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData
+    ):
+        super().__init__(session_id)
+
+        self.request = req
+        self.models_data = models_data
+        self.output_format = output_format
+
+        # convert to multi-filter format, if necessary
+        if isinstance(req.filter, str):
+            req.filter_params = {req.filter: req.filter_params}
+            req.filter = [req.filter]
+
+        if not isinstance(req.image, list):
+            req.image = [req.image]
+
+    def run(self):
+        "Runs the image filtering task on the assigned thread"
+
+        context = runtime.context
+
+        model_manager.resolve_model_paths(self.models_data)
+        model_manager.reload_models_if_necessary(context, self.models_data)
+        model_manager.fail_if_models_did_not_load(context)
+
+        print_task_info(self.request, self.models_data, self.output_format)
+
+        images = filter_images(context, self.request.image, self.request.filter, self.request.filter_params)
+
+        output_format = self.output_format
+        images = [
+            img_to_base64_str(
+                img, output_format.output_format, output_format.output_quality, output_format.output_lossless
+            )
+            for img in images
+        ]
+
+        res = FilterImageResponse(self.request, self.models_data, images=images)
+        res = res.json()
+
+        self.buffer_queue.put(json.dumps(res))
+        log.info("Filter task completed")
+
+        self.response = res
+
+
+def filter_images(context, images, filters, filter_params={}):
+    filters = filters if isinstance(filters, list) else [filters]
+
+    for filter_name in filters:
+        params = filter_params.get(filter_name, {})
+
+        previous_state = before_filter(context, filter_name, params)
+
+        try:
+            images = apply_filters(context, filter_name, images, **params)
+        finally:
+            after_filter(context, filter_name, params, previous_state)
+
+    return images
+
+
+def before_filter(context, filter_name, filter_params):
+    if filter_name == "codeformer":
+        from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
+
+        default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
+        prev_realesrgan_path = None
+
+        upscale_faces = filter_params.get("upscale_faces", False)
+        if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
+            prev_realesrgan_path = context.model_paths.get("realesrgan")
+            context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
+            load_model(context, "realesrgan")
+
+        return prev_realesrgan_path
+
+
+def after_filter(context, filter_name, filter_params, previous_state):
+    if filter_name == "codeformer":
+        prev_realesrgan_path = previous_state
+        if prev_realesrgan_path:
+            context.model_paths["realesrgan"] = prev_realesrgan_path
+            load_model(context, "realesrgan")
+
+
+def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData):
+    req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[")
+    models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
+    output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
+
+    log.info(f"request: {req_str}")
+    log.info(f"models data: {models_data}")
+    log.info(f"output format: {output_format}")
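
To exercise the new standalone filter path end-to-end, a request like the following sketch can be POSTed to the new /filter route (field names come from FilterImageRequest, ModelsData and OutputFormatData above; the base64 string and model name are placeholders):

req = {
    "session_id": "session",
    "image": "<base64-encoded image>",        # placeholder
    "filter": "gfpgan",                       # single name; FilterTask wraps it into ["gfpgan"]
    "filter_params": {},
    "model_paths": {"gfpgan": "GFPGANv1.4"},  # placeholder model name
    "output_format": "png",
}
# POST req to /filter; the result then streams from /image/stream/{task_id},
# per the enqueue_task() response shown earlier.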

View File

@@ -3,70 +3,109 @@ import pprint
 import queue
 import time
 
-from easydiffusion import device_manager
-from easydiffusion.types import GenerateImageRequest
+from easydiffusion import model_manager, runtime
+from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData
 from easydiffusion.types import Image as ResponseImage
-from easydiffusion.types import Response, TaskData, UserInitiatedStop
-from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
+from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop
 from easydiffusion.utils import get_printable_request, log, save_images_to_disk
-from sdkit import Context
-from sdkit.filter import apply_filters
 from sdkit.generate import generate_images
-from sdkit.models import load_model
 from sdkit.utils import (
     diffusers_latent_samples_to_images,
     gc,
     img_to_base64_str,
     img_to_buffer,
     latent_samples_to_images,
-    get_device_usage,
 )
 
-context = Context()  # thread-local
-"""
-runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
-"""
-
-
-def init(device):
-    """
-    Initializes the fields that will be bound to this runtime's context, and sets the current torch device
-    """
-    context.stop_processing = False
-    context.temp_images = {}
-    context.partial_x_samples = None
-    context.model_load_errors = {}
-    context.enable_codeformer = True
-
-    from easydiffusion import app
-
-    app_config = app.getConfig()
-    context.test_diffusers = (
-        app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
-    )
-
-    log.info("Device usage during initialization:")
-    get_device_usage(device, log_info=True, process_usage_only=False)
-
-    device_manager.device_init(context, device)
+from .task import Task
+from .filter_images import filter_images
+
+
+class RenderTask(Task):
+    "For image generation"
+
+    def __init__(
+        self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
+    ):
+        super().__init__(task_data.session_id)
+
+        task_data.request_id = self.id
+        self.render_request: GenerateImageRequest = req  # Initial Request
+        self.task_data: TaskData = task_data
+        self.models_data = models_data
+        self.output_format = output_format
+        self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
+
+    def run(self):
+        "Runs the image generation task on the assigned thread"
+        from easydiffusion import task_manager
+
+        context = runtime.context
+
+        def step_callback():
+            task_manager.keep_task_alive(self)
+            task_manager.current_state = task_manager.ServerStates.Rendering
+
+            if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance(
+                self.error, StopAsyncIteration
+            ):
+                context.stop_processing = True
+                if isinstance(task_manager.current_state_error, StopAsyncIteration):
+                    self.error = task_manager.current_state_error
+                    task_manager.current_state_error = None
+                    log.info(f"Session {self.session_id} sent cancel signal for task {self.id}")
+
+        task_manager.current_state = task_manager.ServerStates.LoadingModel
+        model_manager.resolve_model_paths(self.models_data)
+
+        models_to_force_reload = []
+        if runtime.set_vram_optimizations(context) or self.has_clip_skip_changed(context):
+            models_to_force_reload.append("stable-diffusion")
+
+        model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
+        model_manager.fail_if_models_did_not_load(context)
+
+        task_manager.current_state = task_manager.ServerStates.Rendering
+        self.response = make_images(
+            context,
+            self.render_request,
+            self.task_data,
+            self.models_data,
+            self.output_format,
+            self.buffer_queue,
+            self.temp_images,
+            step_callback,
+        )
+
+    def has_clip_skip_changed(self, context):
+        if not context.test_diffusers:
+            return False
+
+        model = context.models["stable-diffusion"]
+        new_clip_skip = self.models_data.model_params.get("stable-diffusion", {}).get("clip_skip", False)
+        return model["clip_skip"] != new_clip_skip
 
 
 def make_images(
+    context,
     req: GenerateImageRequest,
     task_data: TaskData,
+    models_data: ModelsData,
+    output_format: OutputFormatData,
     data_queue: queue.Queue,
     task_temp_images: list,
     step_callback,
 ):
     context.stop_processing = False
-    print_task_info(req, task_data)
+    print_task_info(req, task_data, models_data, output_format)
 
-    images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
+    images, seeds = make_images_internal(
+        context, req, task_data, models_data, output_format, data_queue, task_temp_images, step_callback
+    )
 
-    res = Response(
-        req,
-        task_data,
-        images=construct_response(images, seeds, task_data, base_seed=req.seed),
+    res = GenerateImageResponse(
+        req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format)
     )
     res = res.json()
     data_queue.put(json.dumps(res))
@@ -75,21 +114,32 @@ def make_images(
     return res
 
-def print_task_info(req: GenerateImageRequest, task_data: TaskData):
-    req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[")
+def print_task_info(
+    req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
+):
+    req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[")
     task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
+    models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
+    output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
+
     log.info(f"request: {req_str}")
     log.info(f"task data: {task_str}")
+    # log.info(f"models data: {models_data}")
+    log.info(f"output format: {output_format}")
 
 def make_images_internal(
+    context,
     req: GenerateImageRequest,
     task_data: TaskData,
+    models_data: ModelsData,
+    output_format: OutputFormatData,
     data_queue: queue.Queue,
     task_temp_images: list,
     step_callback,
 ):
     images, user_stopped = generate_images_internal(
+        context,
         req,
         task_data,
         data_queue,
@@ -98,11 +148,14 @@ def make_images_internal(
         task_data.stream_image_progress,
         task_data.stream_image_progress_interval,
     )
+
     gc(context)
-    filtered_images = filter_images(req, task_data, images, user_stopped)
+
+    filters, filter_params = task_data.filters, task_data.filter_params
+    filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
 
     if task_data.save_to_disk_path is not None:
-        save_images_to_disk(images, filtered_images, req, task_data)
+        save_images_to_disk(images, filtered_images, req, task_data, output_format)
 
     seeds = [*range(req.seed, req.seed + len(images))]
     if task_data.show_only_filtered_image or filtered_images is images:
@@ -112,6 +165,7 @@ def make_images_internal(
 
 def generate_images_internal(
+    context,
     req: GenerateImageRequest,
     task_data: TaskData,
     data_queue: queue.Queue,
@@ -123,6 +177,7 @@ def generate_images_internal(
     context.temp_images.clear()
 
     callback = make_step_callback(
+        context,
         req,
         task_data,
        data_queue,
@@ -155,65 +210,14 @@ def generate_images_internal(
     return images, user_stopped
 
-def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped):
-    if user_stopped:
-        return images
-
-    if task_data.block_nsfw:
-        images = apply_filters(context, "nsfw_checker", images)
-
-    if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
-        default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
-        prev_realesrgan_path = None
-
-        if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
-            prev_realesrgan_path = context.model_paths["realesrgan"]
-            context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
-            load_model(context, "realesrgan")
-
-        try:
-            images = apply_filters(
-                context,
-                "codeformer",
-                images,
-                upscale_faces=task_data.codeformer_upscale_faces,
-                codeformer_fidelity=task_data.codeformer_fidelity,
-            )
-        finally:
-            if prev_realesrgan_path:
-                context.model_paths["realesrgan"] = prev_realesrgan_path
-                load_model(context, "realesrgan")
-    elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
-        images = apply_filters(context, "gfpgan", images)
-
-    if task_data.use_upscale:
-        if "realesrgan" in task_data.use_upscale.lower():
-            images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount)
-        elif task_data.use_upscale == "latent_upscaler":
-            images = apply_filters(
-                context,
-                "latent_upscaler",
-                images,
-                scale=task_data.upscale_amount,
-                latent_upscaler_options={
-                    "prompt": req.prompt,
-                    "negative_prompt": req.negative_prompt,
-                    "seed": req.seed,
-                    "num_inference_steps": task_data.latent_upscaler_steps,
-                    "guidance_scale": 0,
-                },
-            )
-
-    return images
-
-
-def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
+def construct_response(images: list, seeds: list, output_format: OutputFormatData):
     return [
         ResponseImage(
             data=img_to_base64_str(
                 img,
-                task_data.output_format,
-                task_data.output_quality,
-                task_data.output_lossless,
+                output_format.output_format,
+                output_format.output_quality,
+                output_format.output_lossless,
             ),
             seed=seed,
         )
@@ -222,6 +226,7 @@ def construct_response(images: list, seeds: list, task_data: TaskData, base_seed
 
 def make_step_callback(
+    context,
     req: GenerateImageRequest,
     task_data: TaskData,
     data_queue: queue.Queue,
@@ -242,7 +247,7 @@ def make_step_callback(
             images = latent_samples_to_images(context, x_samples)
         if task_data.block_nsfw:
-            images = apply_filters(context, "nsfw_checker", images)
+            images = filter_images(context, images, "nsfw_checker")
         for i, img in enumerate(images):
             buf = img_to_buffer(img, output_format="JPEG")

View File

@ -0,0 +1,47 @@
from threading import Lock
from queue import Queue, Empty as EmptyQueueException
from typing import Any
class Task:
"Task with output queue and completion lock"
def __init__(self, session_id):
self.id = id(self)
self.session_id = session_id
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.error: Exception = None
self.lock: Lock = Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: Queue = Queue() # Queue of JSON string segments
self.response: Any = None # Copy of the last reponse
async def read_buffer_generator(self):
try:
while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done()
yield res
except EmptyQueueException as e:
yield
@property
def status(self):
if self.lock.locked():
return "running"
if isinstance(self.error, StopAsyncIteration):
return "stopped"
if self.error:
return "error"
if not self.buffer_queue.empty():
return "buffer"
if self.response:
return "completed"
return "pending"
@property
def is_pending(self):
return bool(not self.response and not self.error)
def run(self):
"Override this to implement the task's behavior"
pass

View File

@@ -1,4 +1,4 @@
-from typing import Any, List, Union
+from typing import Any, List, Dict, Union
 
 from pydantic import BaseModel
@@ -17,6 +17,8 @@ class GenerateImageRequest(BaseModel):
     init_image: Any = None
     init_image_mask: Any = None
+    control_image: Any = None
+    control_alpha: Union[float, List[float]] = None
     prompt_strength: float = 0.8
     preserve_init_image_color_profile = False
@@ -26,6 +28,35 @@ class GenerateImageRequest(BaseModel):
     tiling: str = "none"  # "none", "x", "y", "xy"
 
+
+class FilterImageRequest(BaseModel):
+    image: Any = None
+    filter: Union[str, List[str]] = None
+    filter_params: dict = {}
+
+
+class ModelsData(BaseModel):
+    """
+    Contains the information related to the models involved in a request.
+
+    - To load a model: set the relative path(s) to the model in `model_paths`. No effect if already loaded.
+    - To unload a model: set the model to `None` in `model_paths`. No effect if already unloaded.
+
+    Models that aren't present in `model_paths` will not be changed.
+    """
+
+    model_paths: Dict[str, Union[str, None, List[str]]] = None
+    "model_type to string path, or list of string paths"
+
+    model_params: Dict[str, Dict[str, Any]] = {}
+    "model_type to dict of parameters"
+
+
+class OutputFormatData(BaseModel):
+    output_format: str = "jpeg"  # or "png" or "webp"
+    output_quality: int = 75
+    output_lossless: bool = False
+
 
 class TaskData(BaseModel):
     request_id: str = None
     session_id: str = "session"
@@ -40,12 +71,12 @@ class TaskData(BaseModel):
     use_vae_model: Union[str, List[str]] = None
     use_hypernetwork_model: Union[str, List[str]] = None
     use_lora_model: Union[str, List[str]] = None
+    use_controlnet_model: Union[str, List[str]] = None
+    filters: List[str] = []
+    filter_params: Dict[str, Dict[str, Any]] = {}
 
     show_only_filtered_image: bool = False
     block_nsfw: bool = False
-    output_format: str = "jpeg"  # or "png" or "webp"
-    output_quality: int = 75
-    output_lossless: bool = False
     metadata_output_format: str = "txt"  # or "json"
     stream_image_progress: bool = False
     stream_image_progress_interval: int = 5
@@ -80,24 +111,38 @@ class Image:
         }
 
-class Response:
+class GenerateImageResponse:
     render_request: GenerateImageRequest
     task_data: TaskData
+    models_data: ModelsData
     images: list
 
-    def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
+    def __init__(
+        self,
+        render_request: GenerateImageRequest,
+        task_data: TaskData,
+        models_data: ModelsData,
+        output_format: OutputFormatData,
+        images: list,
+    ):
         self.render_request = render_request
         self.task_data = task_data
+        self.models_data = models_data
+        self.output_format = output_format
         self.images = images
 
     def json(self):
         del self.render_request.init_image
         del self.render_request.init_image_mask
 
+        task_data = self.task_data.dict()
+        task_data.update(self.output_format.dict())
+
         res = {
             "status": "succeeded",
             "render_request": self.render_request.dict(),
-            "task_data": self.task_data.dict(),
+            "task_data": task_data,
+            # "models_data": self.models_data.dict(), # haven't migrated the UI to the new format (yet)
             "output": [],
         }
 
@@ -107,5 +152,102 @@ class Response:
         return res
 
+
+class FilterImageResponse:
+    request: FilterImageRequest
+    models_data: ModelsData
+    images: list
+
+    def __init__(self, request: FilterImageRequest, models_data: ModelsData, images: list):
+        self.request = request
+        self.models_data = models_data
+        self.images = images
+
+    def json(self):
+        del self.request.image
+
+        res = {
+            "status": "succeeded",
+            "request": self.request.dict(),
+            "models_data": self.models_data.dict(),
+            "output": [],
+        }
+
+        for image in self.images:
+            res["output"].append(image)
+
+        return res
+
 
 class UserInitiatedStop(Exception):
     pass
+
+
+def convert_legacy_render_req_to_new(old_req: dict):
+    new_req = dict(old_req)
+
+    # new keys
+    model_paths = new_req["model_paths"] = {}
+    model_params = new_req["model_params"] = {}
+    filters = new_req["filters"] = []
+    filter_params = new_req["filter_params"] = {}
+
+    # move the model info
+    model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
+    model_paths["vae"] = old_req.get("use_vae_model")
+    model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
+    model_paths["lora"] = old_req.get("use_lora_model")
+    model_paths["controlnet"] = old_req.get("use_controlnet_model")
+
+    model_paths["gfpgan"] = old_req.get("use_face_correction", "")
+    model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None
+
+    model_paths["codeformer"] = old_req.get("use_face_correction", "")
+    model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None
+
+    model_paths["realesrgan"] = old_req.get("use_upscale", "")
+    model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None
+
+    model_paths["latent_upscaler"] = old_req.get("use_upscale", "")
+    model_paths["latent_upscaler"] = (
+        model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
+    )
+
+    if old_req.get("block_nsfw"):
+        model_paths["nsfw_checker"] = "nsfw_checker"
+
+    # move the model params
+    if model_paths["stable-diffusion"]:
+        model_params["stable-diffusion"] = {"clip_skip": bool(old_req["clip_skip"])}
+
+    # move the filter params
+    if model_paths["realesrgan"]:
+        filter_params["realesrgan"] = {"scale": int(old_req["upscale_amount"])}
+    if model_paths["latent_upscaler"]:
+        filter_params["latent_upscaler"] = {
+            "prompt": old_req["prompt"],
+            "negative_prompt": old_req.get("negative_prompt"),
+            "seed": int(old_req.get("seed", 42)),
+            "num_inference_steps": int(old_req.get("latent_upscaler_steps", 10)),
+            "guidance_scale": 0,
+        }
+    if model_paths["codeformer"]:
+        filter_params["codeformer"] = {
+            "upscale_faces": bool(old_req["codeformer_upscale_faces"]),
+            "codeformer_fidelity": float(old_req["codeformer_fidelity"]),
+        }
+
+    # set the filters
+    if old_req.get("block_nsfw"):
+        filters.append("nsfw_checker")
+
+    if model_paths["codeformer"]:
+        filters.append("codeformer")
+    elif model_paths["gfpgan"]:
+        filters.append("gfpgan")
+
+    if model_paths["realesrgan"]:
+        filters.append("realesrgan")
+    elif model_paths["latent_upscaler"]:
+        filters.append("latent_upscaler")
+
+    return new_req
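
Since the UI hasn't migrated yet, the ModelsData docstring above defines the load/unload contract the new format relies on; a small sketch (paths illustrative):

models_data = ModelsData(
    model_paths={
        "stable-diffusion": "sd-v1-4",  # load; no effect if this model is already loaded
        "hypernetwork": None,           # unload; no effect if already unloaded
        # "vae" omitted entirely -> the loaded VAE (if any) is left untouched
    },
    model_params={"stable-diffusion": {"clip_skip": True}},
)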

View File

@@ -7,7 +7,7 @@ from datetime import datetime
 from functools import reduce
 
 from easydiffusion import app
-from easydiffusion.types import GenerateImageRequest, TaskData
+from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
 from numpy import base_repr
 from sdkit.utils import save_dicts, save_images
@@ -114,12 +114,14 @@ def format_file_name(
     return filename_regex.sub("_", format)
 
-def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
+def save_images_to_disk(
+    images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData
+):
     now = time.time()
     app_config = app.getConfig()
     folder_format = app_config.get("folder_format", "$id")
     save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
-    metadata_entries = get_metadata_entries_for_request(req, task_data)
+    metadata_entries = get_metadata_entries_for_request(req, task_data, output_format)
     file_number = calculate_img_number(save_dir_path, task_data)
     make_filename = make_filename_callback(
         app_config.get("filename_format", "$p_$tsb64"),
@@ -134,9 +136,9 @@ def save_images_to_disk(
             filtered_images,
             save_dir_path,
             file_name=make_filename,
-            output_format=task_data.output_format,
-            output_quality=task_data.output_quality,
-            output_lossless=task_data.output_lossless,
+            output_format=output_format.output_format,
+            output_quality=output_format.output_quality,
+            output_lossless=output_format.output_lossless,
         )
         if task_data.metadata_output_format:
             for metadata_output_format in task_data.metadata_output_format.split(","):
@@ -146,7 +148,7 @@ def save_images_to_disk(
                         save_dir_path,
                         file_name=make_filename,
                         output_format=metadata_output_format,
-                        file_format=task_data.output_format,
+                        file_format=output_format.output_format,
                     )
     else:
         make_filter_filename = make_filename_callback(
@@ -162,17 +164,17 @@ def save_images_to_disk(
             images,
             save_dir_path,
             file_name=make_filename,
-            output_format=task_data.output_format,
-            output_quality=task_data.output_quality,
-            output_lossless=task_data.output_lossless,
+            output_format=output_format.output_format,
+            output_quality=output_format.output_quality,
+            output_lossless=output_format.output_lossless,
         )
         save_images(
             filtered_images,
             save_dir_path,
             file_name=make_filter_filename,
-            output_format=task_data.output_format,
-            output_quality=task_data.output_quality,
-            output_lossless=task_data.output_lossless,
+            output_format=output_format.output_format,
+            output_quality=output_format.output_quality,
+            output_lossless=output_format.output_lossless,
         )
         if task_data.metadata_output_format:
             for metadata_output_format in task_data.metadata_output_format.split(","):
@@ -181,20 +183,21 @@ def save_images_to_disk(
                     metadata_entries,
                     save_dir_path,
                     file_name=make_filter_filename,
-                    output_format=task_data.metadata_output_format,
-                    file_format=task_data.output_format,
+                    output_format=metadata_output_format,
+                    file_format=output_format.output_format,
                 )
 
-def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
-    metadata = get_printable_request(req, task_data)
+def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
+    metadata = get_printable_request(req, task_data, output_format)
 
     # if text, format it in the text format expected by the UI
     is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",")
     if is_txt_format:
 
         def format_value(value):
             if isinstance(value, list):
-                return ", ".join([ str(it) for it in value ])
+                return ", ".join([str(it) for it in value])
             return value
 
         metadata = {
@@ -208,9 +211,10 @@ def get_metadata_entries_for_request(
     return entries
 
-def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
+def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
     req_metadata = req.dict()
     task_data_metadata = task_data.dict()
+    task_data_metadata.update(output_format.dict())
 
     app_config = app.getConfig()
     using_diffusers = app_config.get("test_diffusers", False)
@@ -224,6 +228,7 @@ def get_printable_request(
             metadata[key] = task_data_metadata[key]
         elif key == "use_embedding_models" and using_diffusers:
             embeddings_extensions = {".pt", ".bin", ".safetensors"}
+
             def scan_directory(directory_path: str):
                 used_embeddings = []
                 for entry in os.scandir(directory_path):
@@ -232,15 +237,18 @@ def get_printable_request(
                         if entry_extension not in embeddings_extensions:
                             continue
 
-                        embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])")
+                        embedding_name_regex = regex.compile(
+                            r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
+                        )
                         if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
                             used_embeddings.append(entry.path)
                     elif entry.is_dir():
                         used_embeddings.extend(scan_directory(entry.path))
                 return used_embeddings
+
             used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
             metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
 
     # Clean up the metadata
     if req.init_image is None and "prompt_strength" in metadata:
         del metadata["prompt_strength"]
@@ -254,7 +262,9 @@ def get_printable_request(
         del metadata["latent_upscaler_steps"]
 
     if not using_diffusers:
-        for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata):
+        for key in (
+            x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata
+        ):
             del metadata[key]
 
     return metadata

View File

@@ -1047,17 +1047,22 @@
         }
     }
 
     class FilterTask extends Task {
-        constructor(options = {}) {}
+        constructor(options = {}) {
+            super(options)
+        }
+
         /** Send current task to server.
          * @param {*} [timeout=-1] Optional timeout value in ms
          * @returns the response from the render request.
          * @memberof Task
         */
         async post(timeout = -1) {
-            let jsonResponse = await super.post("/filter", timeout)
+            let res = await super.post("/filter", timeout)
             //this._setId(jsonResponse.task)
             this._setStatus(TaskStatus.waiting)
+            return res
         }
+
+        checkReqBody() {}
+
         enqueue(progressCallback) {
             return Task.enqueueNew(this, FilterTask, progressCallback)
         }
@@ -1068,6 +1073,20 @@
             if (this.isStopped) {
                 return
             }
+
+            this._setStatus(TaskStatus.pending)
+            progressCallback?.call(this, { reqBody: this._reqBody })
+            Object.freeze(this._reqBody)
+
+            // Post task request to backend
+            let renderRes = undefined
+            try {
+                renderRes = yield this.post()
+                yield progressCallback?.call(this, { renderResponse: renderRes })
+            } catch (e) {
+                yield progressCallback?.call(this, { detail: e.message })
+                throw e
+            }
         }
 
         static start(task, progressCallback) {
             if (typeof task !== "object") {