forked from extern/easydiffusion
Mega refactor of the task processing and rendering logic; Split filter into a separate task, and add support for running filter tasks individually; Change the format for sending model and filter data from the API, but maintain backwards compatibility for now with the old API
This commit is contained in:
parent
b93c624efa
commit
e61549e0cd
@ -5,7 +5,7 @@ import traceback
|
||||
from typing import Union
|
||||
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import TaskData
|
||||
from easydiffusion.types import ModelsData
|
||||
from easydiffusion.utils import log
|
||||
from sdkit import Context
|
||||
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
|
||||
@ -57,7 +57,9 @@ def init():
|
||||
|
||||
|
||||
def load_default_models(context: Context):
|
||||
set_vram_optimizations(context)
|
||||
from easydiffusion import runtime
|
||||
|
||||
runtime.set_vram_optimizations(context)
|
||||
|
||||
config = app.getConfig()
|
||||
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
|
||||
@ -138,43 +140,32 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
||||
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
|
||||
|
||||
|
||||
def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else ""
|
||||
upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else ""
|
||||
|
||||
model_paths_in_req = {
|
||||
"stable-diffusion": task_data.use_stable_diffusion_model,
|
||||
"vae": task_data.use_vae_model,
|
||||
"hypernetwork": task_data.use_hypernetwork_model,
|
||||
"codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None,
|
||||
"gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None,
|
||||
"realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None,
|
||||
"latent_upscaler": True if "latent_upscaler" in upscale_lower else None,
|
||||
"nsfw_checker": True if task_data.block_nsfw else None,
|
||||
"lora": task_data.use_lora_model,
|
||||
}
|
||||
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
|
||||
models_to_reload = {
|
||||
model_type: path
|
||||
for model_type, path in model_paths_in_req.items()
|
||||
for model_type, path in models_data.model_paths.items()
|
||||
if context.model_paths.get(model_type) != path
|
||||
}
|
||||
|
||||
if task_data.codeformer_upscale_faces:
|
||||
if models_data.model_paths.get("codeformer"):
|
||||
if "realesrgan" not in models_to_reload and "realesrgan" not in context.models:
|
||||
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
|
||||
models_to_reload["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
|
||||
elif "realesrgan" in models_to_reload and models_to_reload["realesrgan"] is None:
|
||||
del models_to_reload["realesrgan"] # don't unload realesrgan
|
||||
|
||||
if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD
|
||||
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
|
||||
for model_type in models_to_force_reload:
|
||||
if model_type not in models_data.model_paths:
|
||||
continue
|
||||
models_to_reload[model_type] = models_data.model_paths[model_type]
|
||||
|
||||
for model_type, model_path_in_req in models_to_reload.items():
|
||||
context.model_paths[model_type] = model_path_in_req
|
||||
|
||||
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
||||
extra_params = models_data.model_params.get(model_type, {})
|
||||
try:
|
||||
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
||||
action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already
|
||||
if model_type in context.model_load_errors:
|
||||
del context.model_load_errors[model_type]
|
||||
except Exception as e:
|
||||
@ -183,24 +174,15 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks
|
||||
|
||||
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = resolve_model_to_use(
|
||||
task_data.use_stable_diffusion_model, model_type="stable-diffusion"
|
||||
)
|
||||
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
|
||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
|
||||
task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora")
|
||||
|
||||
if task_data.use_face_correction:
|
||||
if "gfpgan" in task_data.use_face_correction.lower():
|
||||
model_type = "gfpgan"
|
||||
elif "codeformer" in task_data.use_face_correction.lower():
|
||||
model_type = "codeformer"
|
||||
def resolve_model_paths(models_data: ModelsData):
|
||||
model_paths = models_data.model_paths
|
||||
for model_type in model_paths:
|
||||
if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths
|
||||
continue
|
||||
if model_type == "codeformer":
|
||||
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
|
||||
|
||||
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type)
|
||||
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
|
||||
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
|
||||
model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
|
||||
|
||||
|
||||
def fail_if_models_did_not_load(context: Context):
|
||||
@ -235,17 +217,6 @@ def download_if_necessary(model_type: str, file_name: str, model_id: str):
|
||||
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR)
|
||||
|
||||
|
||||
def set_vram_optimizations(context: Context):
|
||||
config = app.getConfig()
|
||||
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||
|
||||
if vram_usage_level != context.vram_usage_level:
|
||||
context.vram_usage_level = vram_usage_level
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def migrate_legacy_model_location():
|
||||
'Move the models inside the legacy "stable-diffusion" folder, to their respective folders'
|
||||
|
||||
@ -266,16 +237,6 @@ def any_model_exists(model_type: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def set_clip_skip(context: Context, task_data: TaskData):
|
||||
clip_skip = task_data.clip_skip
|
||||
|
||||
if clip_skip != context.clip_skip:
|
||||
context.clip_skip = clip_skip
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
|
53
ui/easydiffusion/runtime.py
Normal file
53
ui/easydiffusion/runtime.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""
|
||||
A runtime that runs on a specific device (in a thread).
|
||||
|
||||
It can run various tasks like image generation, image filtering, model merge etc by using that thread-local context.
|
||||
|
||||
This creates an `sdkit.Context` that's bound to the device specified while calling the `init()` function.
|
||||
"""
|
||||
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.utils import log
|
||||
from sdkit import Context
|
||||
from sdkit.utils import get_device_usage
|
||||
|
||||
context = Context() # thread-local
|
||||
"""
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
"""
|
||||
|
||||
|
||||
def init(device):
|
||||
"""
|
||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||
"""
|
||||
context.stop_processing = False
|
||||
context.temp_images = {}
|
||||
context.partial_x_samples = None
|
||||
context.model_load_errors = {}
|
||||
context.enable_codeformer = True
|
||||
|
||||
from easydiffusion import app
|
||||
|
||||
app_config = app.getConfig()
|
||||
context.test_diffusers = (
|
||||
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
|
||||
)
|
||||
|
||||
log.info("Device usage during initialization:")
|
||||
get_device_usage(device, log_info=True, process_usage_only=False)
|
||||
|
||||
device_manager.device_init(context, device)
|
||||
|
||||
|
||||
def set_vram_optimizations(context: Context):
|
||||
from easydiffusion import app
|
||||
|
||||
config = app.getConfig()
|
||||
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||
|
||||
if vram_usage_level != context.vram_usage_level:
|
||||
context.vram_usage_level = vram_usage_level
|
||||
return True
|
||||
|
||||
return False
|
@ -9,7 +9,16 @@ import traceback
|
||||
from typing import List, Union
|
||||
|
||||
from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
|
||||
from easydiffusion.tasks import RenderTask, FilterTask
|
||||
from easydiffusion.types import (
|
||||
GenerateImageRequest,
|
||||
FilterImageRequest,
|
||||
MergeRequest,
|
||||
TaskData,
|
||||
ModelsData,
|
||||
OutputFormatData,
|
||||
convert_legacy_render_req_to_new,
|
||||
)
|
||||
from easydiffusion.utils import log
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@ -97,6 +106,10 @@ def init():
|
||||
def render(req: dict):
|
||||
return render_internal(req)
|
||||
|
||||
@server_api.post("/filter")
|
||||
def render(req: dict):
|
||||
return filter_internal(req)
|
||||
|
||||
@server_api.post("/model/merge")
|
||||
def model_merge(req: dict):
|
||||
print(req)
|
||||
@ -228,9 +241,13 @@ def ping_internal(session_id: str = None):
|
||||
|
||||
def render_internal(req: dict):
|
||||
try:
|
||||
req = convert_legacy_render_req_to_new(req)
|
||||
|
||||
# separate out the request data into rendering and task-specific data
|
||||
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
|
||||
task_data: TaskData = TaskData.parse_obj(req)
|
||||
models_data: ModelsData = ModelsData.parse_obj(req)
|
||||
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
|
||||
|
||||
# Overwrite user specified save path
|
||||
config = app.getConfig()
|
||||
@ -240,28 +257,53 @@ def render_internal(req: dict):
|
||||
render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
|
||||
|
||||
app.save_to_config(
|
||||
task_data.use_stable_diffusion_model,
|
||||
task_data.use_vae_model,
|
||||
task_data.use_hypernetwork_model,
|
||||
models_data.model_paths.get("stable-diffusion"),
|
||||
models_data.model_paths.get("vae"),
|
||||
models_data.model_paths.get("hypernetwork"),
|
||||
task_data.vram_usage_level,
|
||||
)
|
||||
|
||||
# enqueue the task
|
||||
new_task = task_manager.render(render_req, task_data)
|
||||
task = RenderTask(render_req, task_data, models_data, output_format)
|
||||
return enqueue_task(task)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def filter_internal(req: dict):
|
||||
try:
|
||||
session_id = req.get("session_id", "session")
|
||||
filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req)
|
||||
models_data: ModelsData = ModelsData.parse_obj(req)
|
||||
output_format: OutputFormatData = OutputFormatData.parse_obj(req)
|
||||
|
||||
# enqueue the task
|
||||
task = FilterTask(filter_req, session_id, models_data, output_format)
|
||||
return enqueue_task(task)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def enqueue_task(task):
|
||||
try:
|
||||
task_manager.enqueue_task(task)
|
||||
response = {
|
||||
"status": str(task_manager.current_state),
|
||||
"queue": len(task_manager.tasks_queue),
|
||||
"stream": f"/image/stream/{id(new_task)}",
|
||||
"task": id(new_task),
|
||||
"stream": f"/image/stream/{task.id}",
|
||||
"task": task.id,
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def model_merge_internal(req: dict):
|
||||
|
@ -17,7 +17,7 @@ from typing import Any, Hashable
|
||||
|
||||
import torch
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||
from easydiffusion.tasks import Task
|
||||
from easydiffusion.utils import log
|
||||
from sdkit.utils import gc
|
||||
|
||||
@ -27,6 +27,7 @@ LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
||||
|
||||
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
|
||||
MAX_OVERLOAD_ALLOWED_RATIO = 2 # i.e. 2x pending tasks compared to the number of render threads
|
||||
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
@ -58,46 +59,6 @@ class ServerStates:
|
||||
pass
|
||||
|
||||
|
||||
class RenderTask: # Task with output queue and completion lock.
|
||||
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
|
||||
task_data.request_id = id(self)
|
||||
self.render_request: GenerateImageRequest = req # Initial Request
|
||||
self.task_data: TaskData = task_data
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||
self.error: Exception = None
|
||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||
|
||||
async def read_buffer_generator(self):
|
||||
try:
|
||||
while not self.buffer_queue.empty():
|
||||
res = self.buffer_queue.get(block=False)
|
||||
self.buffer_queue.task_done()
|
||||
yield res
|
||||
except queue.Empty as e:
|
||||
yield
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.lock.locked():
|
||||
return "running"
|
||||
if isinstance(self.error, StopAsyncIteration):
|
||||
return "stopped"
|
||||
if self.error:
|
||||
return "error"
|
||||
if not self.buffer_queue.empty():
|
||||
return "buffer"
|
||||
if self.response:
|
||||
return "completed"
|
||||
return "pending"
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class DataCache:
|
||||
def __init__(self):
|
||||
@ -123,8 +84,8 @@ class DataCache:
|
||||
# Remove Items
|
||||
for key in to_delete:
|
||||
(_, val) = self._base[key]
|
||||
if isinstance(val, RenderTask):
|
||||
log.debug(f"RenderTask {key} expired. Data removed.")
|
||||
if isinstance(val, Task):
|
||||
log.debug(f"Task {key} expired. Data removed.")
|
||||
elif isinstance(val, SessionState):
|
||||
log.debug(f"Session {key} expired. Data removed.")
|
||||
else:
|
||||
@ -220,8 +181,8 @@ class SessionState:
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
|
||||
def put(self, task, ttl=TASK_TTL):
|
||||
task_id = id(task)
|
||||
def put(self, task: Task, ttl=TASK_TTL):
|
||||
task_id = task.id
|
||||
self._tasks_ids.append(task_id)
|
||||
if not task_cache.put(task_id, task, ttl):
|
||||
return False
|
||||
@ -230,11 +191,16 @@ class SessionState:
|
||||
return True
|
||||
|
||||
|
||||
def keep_task_alive(task: Task):
|
||||
task_cache.keep(task.id, TASK_TTL)
|
||||
session_cache.keep(task.session_id, TASK_TTL)
|
||||
|
||||
|
||||
def thread_get_next_task():
|
||||
from easydiffusion import renderer
|
||||
from easydiffusion import runtime
|
||||
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
|
||||
log.warn(f"Render thread on device: {runtime.context.device} failed to acquire manager lock.")
|
||||
return None
|
||||
if len(tasks_queue) <= 0:
|
||||
manager_lock.release()
|
||||
@ -242,7 +208,7 @@ def thread_get_next_task():
|
||||
task = None
|
||||
try: # Select a render task.
|
||||
for queued_task in tasks_queue:
|
||||
if queued_task.render_device and renderer.context.device != queued_task.render_device:
|
||||
if queued_task.render_device and runtime.context.device != queued_task.render_device:
|
||||
# Is asking for a specific render device.
|
||||
if is_alive(queued_task.render_device) > 0:
|
||||
continue # requested device alive, skip current one.
|
||||
@ -251,7 +217,7 @@ def thread_get_next_task():
|
||||
queued_task.error = Exception(queued_task.render_device + " is not currently active.")
|
||||
task = queued_task
|
||||
break
|
||||
if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
|
||||
if not queued_task.render_device and runtime.context.device == "cpu" and is_alive() > 1:
|
||||
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
|
||||
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
|
||||
task = queued_task
|
||||
@ -266,19 +232,19 @@ def thread_get_next_task():
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error
|
||||
|
||||
from easydiffusion import model_manager, renderer
|
||||
from easydiffusion import model_manager, runtime
|
||||
|
||||
try:
|
||||
renderer.init(device)
|
||||
runtime.init(device)
|
||||
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
"device": renderer.context.device,
|
||||
"device_name": renderer.context.device_name,
|
||||
"device": runtime.context.device,
|
||||
"device_name": runtime.context.device_name,
|
||||
"alive": True,
|
||||
}
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
model_manager.load_default_models(renderer.context)
|
||||
model_manager.load_default_models(runtime.context)
|
||||
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
@ -290,8 +256,8 @@ def thread_render(device):
|
||||
session_cache.clean()
|
||||
task_cache.clean()
|
||||
if not weak_thread_data[threading.current_thread()]["alive"]:
|
||||
log.info(f"Shutting down thread for device {renderer.context.device}")
|
||||
model_manager.unload_all(renderer.context)
|
||||
log.info(f"Shutting down thread for device {runtime.context.device}")
|
||||
model_manager.unload_all(runtime.context)
|
||||
return
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
@ -311,62 +277,31 @@ def thread_render(device):
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
|
||||
log.info(f"Session {task.session_id} starting task {task.id} on {runtime.context.device_name}")
|
||||
if not task.lock.acquire(blocking=False):
|
||||
raise Exception("Got locked task from queue.")
|
||||
try:
|
||||
task.run()
|
||||
|
||||
def step_callback():
|
||||
global current_state_error
|
||||
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
|
||||
if (
|
||||
isinstance(current_state_error, SystemExit)
|
||||
or isinstance(current_state_error, StopAsyncIteration)
|
||||
or isinstance(task.error, StopAsyncIteration)
|
||||
):
|
||||
renderer.context.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
model_manager.resolve_model_paths(task.task_data)
|
||||
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
||||
model_manager.fail_if_models_did_not_load(renderer.context)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = renderer.make_images(
|
||||
task.render_request,
|
||||
task.task_data,
|
||||
task.buffer_queue,
|
||||
task.temp_images,
|
||||
step_callback,
|
||||
)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
keep_task_alive(task)
|
||||
except Exception as e:
|
||||
task.error = str(e)
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
log.error(traceback.format_exc())
|
||||
finally:
|
||||
gc(renderer.context)
|
||||
gc(runtime.context)
|
||||
task.lock.release()
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
|
||||
keep_task_alive(task)
|
||||
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
|
||||
log.info(f"Session {task.session_id} task {task.id} cancelled!")
|
||||
elif task.error is not None:
|
||||
log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
|
||||
log.info(f"Session {task.session_id} task {task.id} failed!")
|
||||
else:
|
||||
log.info(
|
||||
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
|
||||
)
|
||||
log.info(f"Session {task.session_id} task {task.id} completed by {runtime.context.device_name}.")
|
||||
current_state = ServerStates.Online
|
||||
|
||||
|
||||
@ -548,28 +483,27 @@ def shutdown_event(): # Signal render thread to close on shutdown
|
||||
current_state_error = SystemExit("Application shutting down.")
|
||||
|
||||
|
||||
def render(render_req: GenerateImageRequest, task_data: TaskData):
|
||||
def enqueue_task(task: Task):
|
||||
current_thread_count = is_alive()
|
||||
if current_thread_count <= 0: # Render thread is dead
|
||||
raise ChildProcessError("Rendering thread has died.")
|
||||
|
||||
# Alive, check if task in cache
|
||||
session = get_cached_session(task_data.session_id, update_ttl=True)
|
||||
session = get_cached_session(task.session_id, update_ttl=True)
|
||||
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||
if current_thread_count < len(pending_tasks):
|
||||
if len(pending_tasks) > current_thread_count * MAX_OVERLOAD_ALLOWED_RATIO:
|
||||
raise ConnectionRefusedError(
|
||||
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
|
||||
f"Session {task.session_id} already has {len(pending_tasks)} pending tasks, with {current_thread_count} workers."
|
||||
)
|
||||
|
||||
new_task = RenderTask(render_req, task_data)
|
||||
if session.put(new_task, TASK_TTL):
|
||||
if session.put(task, TASK_TTL):
|
||||
# Use twice the normal timeout for adding user requests.
|
||||
# Tries to force session.put to fail before tasks_queue.put would.
|
||||
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||
try:
|
||||
tasks_queue.append(new_task)
|
||||
tasks_queue.append(task)
|
||||
idle_event.set()
|
||||
return new_task
|
||||
return task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
raise RuntimeError("Failed to add task to cache.")
|
||||
|
3
ui/easydiffusion/tasks/__init__.py
Normal file
3
ui/easydiffusion/tasks/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .task import Task
|
||||
from .render_images import RenderTask
|
||||
from .filter_images import FilterTask
|
110
ui/easydiffusion/tasks/filter_images.py
Normal file
110
ui/easydiffusion/tasks/filter_images.py
Normal file
@ -0,0 +1,110 @@
|
||||
import json
|
||||
import pprint
|
||||
|
||||
from sdkit.filter import apply_filters
|
||||
from sdkit.models import load_model
|
||||
from sdkit.utils import img_to_base64_str, log
|
||||
|
||||
from easydiffusion import model_manager, runtime
|
||||
from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData
|
||||
|
||||
from .task import Task
|
||||
|
||||
|
||||
class FilterTask(Task):
|
||||
"For applying filters to input images"
|
||||
|
||||
def __init__(
|
||||
self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData
|
||||
):
|
||||
super().__init__(session_id)
|
||||
|
||||
self.request = req
|
||||
self.models_data = models_data
|
||||
self.output_format = output_format
|
||||
|
||||
# convert to multi-filter format, if necessary
|
||||
if isinstance(req.filter, str):
|
||||
req.filter_params = {req.filter: req.filter_params}
|
||||
req.filter = [req.filter]
|
||||
|
||||
if not isinstance(req.image, list):
|
||||
req.image = [req.image]
|
||||
|
||||
def run(self):
|
||||
"Runs the image filtering task on the assigned thread"
|
||||
|
||||
context = runtime.context
|
||||
|
||||
model_manager.resolve_model_paths(self.models_data)
|
||||
model_manager.reload_models_if_necessary(context, self.models_data)
|
||||
model_manager.fail_if_models_did_not_load(context)
|
||||
|
||||
print_task_info(self.request, self.models_data, self.output_format)
|
||||
|
||||
images = filter_images(context, self.request.image, self.request.filter, self.request.filter_params)
|
||||
|
||||
output_format = self.output_format
|
||||
images = [
|
||||
img_to_base64_str(
|
||||
img, output_format.output_format, output_format.output_quality, output_format.output_lossless
|
||||
)
|
||||
for img in images
|
||||
]
|
||||
|
||||
res = FilterImageResponse(self.request, self.models_data, images=images)
|
||||
res = res.json()
|
||||
self.buffer_queue.put(json.dumps(res))
|
||||
log.info("Filter task completed")
|
||||
|
||||
self.response = res
|
||||
|
||||
|
||||
def filter_images(context, images, filters, filter_params={}):
|
||||
filters = filters if isinstance(filters, list) else [filters]
|
||||
|
||||
for filter_name in filters:
|
||||
params = filter_params.get(filter_name, {})
|
||||
|
||||
previous_state = before_filter(context, filter_name, params)
|
||||
|
||||
try:
|
||||
images = apply_filters(context, filter_name, images, **params)
|
||||
finally:
|
||||
after_filter(context, filter_name, params, previous_state)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def before_filter(context, filter_name, filter_params):
|
||||
if filter_name == "codeformer":
|
||||
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
|
||||
|
||||
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
|
||||
prev_realesrgan_path = None
|
||||
|
||||
upscale_faces = filter_params.get("upscale_faces", False)
|
||||
if upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
|
||||
prev_realesrgan_path = context.model_paths.get("realesrgan")
|
||||
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
|
||||
load_model(context, "realesrgan")
|
||||
|
||||
return prev_realesrgan_path
|
||||
|
||||
|
||||
def after_filter(context, filter_name, filter_params, previous_state):
|
||||
if filter_name == "codeformer":
|
||||
prev_realesrgan_path = previous_state
|
||||
if prev_realesrgan_path:
|
||||
context.model_paths["realesrgan"] = prev_realesrgan_path
|
||||
load_model(context, "realesrgan")
|
||||
|
||||
|
||||
def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData):
|
||||
req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[")
|
||||
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
|
||||
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
|
||||
|
||||
log.info(f"request: {req_str}")
|
||||
log.info(f"models data: {models_data}")
|
||||
log.info(f"output format: {output_format}")
|
@ -3,70 +3,109 @@ import pprint
|
||||
import queue
|
||||
import time
|
||||
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.types import GenerateImageRequest
|
||||
from easydiffusion import model_manager, runtime
|
||||
from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData
|
||||
from easydiffusion.types import Image as ResponseImage
|
||||
from easydiffusion.types import Response, TaskData, UserInitiatedStop
|
||||
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
|
||||
from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop
|
||||
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
|
||||
from sdkit import Context
|
||||
from sdkit.filter import apply_filters
|
||||
from sdkit.generate import generate_images
|
||||
from sdkit.models import load_model
|
||||
from sdkit.utils import (
|
||||
diffusers_latent_samples_to_images,
|
||||
gc,
|
||||
img_to_base64_str,
|
||||
img_to_buffer,
|
||||
latent_samples_to_images,
|
||||
get_device_usage,
|
||||
)
|
||||
|
||||
context = Context() # thread-local
|
||||
"""
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
"""
|
||||
from .task import Task
|
||||
from .filter_images import filter_images
|
||||
|
||||
|
||||
def init(device):
|
||||
"""
|
||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||
"""
|
||||
context.stop_processing = False
|
||||
context.temp_images = {}
|
||||
context.partial_x_samples = None
|
||||
context.model_load_errors = {}
|
||||
context.enable_codeformer = True
|
||||
class RenderTask(Task):
|
||||
"For image generation"
|
||||
|
||||
from easydiffusion import app
|
||||
def __init__(
|
||||
self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
|
||||
):
|
||||
super().__init__(task_data.session_id)
|
||||
|
||||
app_config = app.getConfig()
|
||||
context.test_diffusers = (
|
||||
app_config.get("test_diffusers", False) and app_config.get("update_branch", "main") != "main"
|
||||
)
|
||||
task_data.request_id = self.id
|
||||
self.render_request: GenerateImageRequest = req # Initial Request
|
||||
self.task_data: TaskData = task_data
|
||||
self.models_data = models_data
|
||||
self.output_format = output_format
|
||||
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||
|
||||
log.info("Device usage during initialization:")
|
||||
get_device_usage(device, log_info=True, process_usage_only=False)
|
||||
def run(self):
|
||||
"Runs the image generation task on the assigned thread"
|
||||
|
||||
device_manager.device_init(context, device)
|
||||
from easydiffusion import task_manager
|
||||
|
||||
context = runtime.context
|
||||
|
||||
def step_callback():
|
||||
task_manager.keep_task_alive(self)
|
||||
task_manager.current_state = task_manager.ServerStates.Rendering
|
||||
|
||||
if isinstance(task_manager.current_state_error, (SystemExit, StopAsyncIteration)) or isinstance(
|
||||
self.error, StopAsyncIteration
|
||||
):
|
||||
context.stop_processing = True
|
||||
if isinstance(task_manager.current_state_error, StopAsyncIteration):
|
||||
self.error = task_manager.current_state_error
|
||||
task_manager.current_state_error = None
|
||||
log.info(f"Session {self.session_id} sent cancel signal for task {self.id}")
|
||||
|
||||
task_manager.current_state = task_manager.ServerStates.LoadingModel
|
||||
model_manager.resolve_model_paths(self.models_data)
|
||||
|
||||
models_to_force_reload = []
|
||||
if runtime.set_vram_optimizations(context) or self.has_clip_skip_changed(context):
|
||||
models_to_force_reload.append("stable-diffusion")
|
||||
|
||||
model_manager.reload_models_if_necessary(context, self.models_data, models_to_force_reload)
|
||||
model_manager.fail_if_models_did_not_load(context)
|
||||
|
||||
task_manager.current_state = task_manager.ServerStates.Rendering
|
||||
self.response = make_images(
|
||||
context,
|
||||
self.render_request,
|
||||
self.task_data,
|
||||
self.models_data,
|
||||
self.output_format,
|
||||
self.buffer_queue,
|
||||
self.temp_images,
|
||||
step_callback,
|
||||
)
|
||||
|
||||
def has_clip_skip_changed(self, context):
|
||||
if not context.test_diffusers:
|
||||
return False
|
||||
|
||||
model = context.models["stable-diffusion"]
|
||||
new_clip_skip = self.models_data.model_params.get("stable-diffusion", {}).get("clip_skip", False)
|
||||
return model["clip_skip"] != new_clip_skip
|
||||
|
||||
|
||||
def make_images(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
output_format: OutputFormatData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
context.stop_processing = False
|
||||
print_task_info(req, task_data)
|
||||
print_task_info(req, task_data, models_data, output_format)
|
||||
|
||||
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||
images, seeds = make_images_internal(
|
||||
context, req, task_data, models_data, output_format, data_queue, task_temp_images, step_callback
|
||||
)
|
||||
|
||||
res = Response(
|
||||
req,
|
||||
task_data,
|
||||
images=construct_response(images, seeds, task_data, base_seed=req.seed),
|
||||
res = GenerateImageResponse(
|
||||
req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format)
|
||||
)
|
||||
res = res.json()
|
||||
data_queue.put(json.dumps(res))
|
||||
@ -75,21 +114,32 @@ def make_images(
|
||||
return res
|
||||
|
||||
|
||||
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
||||
req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[")
|
||||
def print_task_info(
|
||||
req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData
|
||||
):
|
||||
req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[")
|
||||
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
||||
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
|
||||
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
|
||||
|
||||
log.info(f"request: {req_str}")
|
||||
log.info(f"task data: {task_str}")
|
||||
# log.info(f"models data: {models_data}")
|
||||
log.info(f"output format: {output_format}")
|
||||
|
||||
|
||||
def make_images_internal(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
output_format: OutputFormatData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
images, user_stopped = generate_images_internal(
|
||||
context,
|
||||
req,
|
||||
task_data,
|
||||
data_queue,
|
||||
@ -98,11 +148,14 @@ def make_images_internal(
|
||||
task_data.stream_image_progress,
|
||||
task_data.stream_image_progress_interval,
|
||||
)
|
||||
|
||||
gc(context)
|
||||
filtered_images = filter_images(req, task_data, images, user_stopped)
|
||||
|
||||
filters, filter_params = task_data.filters, task_data.filter_params
|
||||
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
|
||||
|
||||
if task_data.save_to_disk_path is not None:
|
||||
save_images_to_disk(images, filtered_images, req, task_data)
|
||||
save_images_to_disk(images, filtered_images, req, task_data, output_format)
|
||||
|
||||
seeds = [*range(req.seed, req.seed + len(images))]
|
||||
if task_data.show_only_filtered_image or filtered_images is images:
|
||||
@ -112,6 +165,7 @@ def make_images_internal(
|
||||
|
||||
|
||||
def generate_images_internal(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
@ -123,6 +177,7 @@ def generate_images_internal(
|
||||
context.temp_images.clear()
|
||||
|
||||
callback = make_step_callback(
|
||||
context,
|
||||
req,
|
||||
task_data,
|
||||
data_queue,
|
||||
@ -155,65 +210,14 @@ def generate_images_internal(
|
||||
return images, user_stopped
|
||||
|
||||
|
||||
def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped):
|
||||
if user_stopped:
|
||||
return images
|
||||
|
||||
if task_data.block_nsfw:
|
||||
images = apply_filters(context, "nsfw_checker", images)
|
||||
|
||||
if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
|
||||
default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
|
||||
prev_realesrgan_path = None
|
||||
if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
|
||||
prev_realesrgan_path = context.model_paths["realesrgan"]
|
||||
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
|
||||
load_model(context, "realesrgan")
|
||||
|
||||
try:
|
||||
images = apply_filters(
|
||||
context,
|
||||
"codeformer",
|
||||
images,
|
||||
upscale_faces=task_data.codeformer_upscale_faces,
|
||||
codeformer_fidelity=task_data.codeformer_fidelity,
|
||||
)
|
||||
finally:
|
||||
if prev_realesrgan_path:
|
||||
context.model_paths["realesrgan"] = prev_realesrgan_path
|
||||
load_model(context, "realesrgan")
|
||||
elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
|
||||
images = apply_filters(context, "gfpgan", images)
|
||||
|
||||
if task_data.use_upscale:
|
||||
if "realesrgan" in task_data.use_upscale.lower():
|
||||
images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount)
|
||||
elif task_data.use_upscale == "latent_upscaler":
|
||||
images = apply_filters(
|
||||
context,
|
||||
"latent_upscaler",
|
||||
images,
|
||||
scale=task_data.upscale_amount,
|
||||
latent_upscaler_options={
|
||||
"prompt": req.prompt,
|
||||
"negative_prompt": req.negative_prompt,
|
||||
"seed": req.seed,
|
||||
"num_inference_steps": task_data.latent_upscaler_steps,
|
||||
"guidance_scale": 0,
|
||||
},
|
||||
)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
||||
def construct_response(images: list, seeds: list, output_format: OutputFormatData):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=img_to_base64_str(
|
||||
img,
|
||||
task_data.output_format,
|
||||
task_data.output_quality,
|
||||
task_data.output_lossless,
|
||||
output_format.output_format,
|
||||
output_format.output_quality,
|
||||
output_format.output_lossless,
|
||||
),
|
||||
seed=seed,
|
||||
)
|
||||
@ -222,6 +226,7 @@ def construct_response(images: list, seeds: list, task_data: TaskData, base_seed
|
||||
|
||||
|
||||
def make_step_callback(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
@ -242,7 +247,7 @@ def make_step_callback(
|
||||
images = latent_samples_to_images(context, x_samples)
|
||||
|
||||
if task_data.block_nsfw:
|
||||
images = apply_filters(context, "nsfw_checker", images)
|
||||
images = filter_images(context, images, "nsfw_checker")
|
||||
|
||||
for i, img in enumerate(images):
|
||||
buf = img_to_buffer(img, output_format="JPEG")
|
47
ui/easydiffusion/tasks/task.py
Normal file
47
ui/easydiffusion/tasks/task.py
Normal file
@ -0,0 +1,47 @@
|
||||
from threading import Lock
|
||||
from queue import Queue, Empty as EmptyQueueException
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Task:
|
||||
"Task with output queue and completion lock"
|
||||
|
||||
def __init__(self, session_id):
|
||||
self.id = id(self)
|
||||
self.session_id = session_id
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.error: Exception = None
|
||||
self.lock: Lock = Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: Queue = Queue() # Queue of JSON string segments
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
|
||||
async def read_buffer_generator(self):
|
||||
try:
|
||||
while not self.buffer_queue.empty():
|
||||
res = self.buffer_queue.get(block=False)
|
||||
self.buffer_queue.task_done()
|
||||
yield res
|
||||
except EmptyQueueException as e:
|
||||
yield
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.lock.locked():
|
||||
return "running"
|
||||
if isinstance(self.error, StopAsyncIteration):
|
||||
return "stopped"
|
||||
if self.error:
|
||||
return "error"
|
||||
if not self.buffer_queue.empty():
|
||||
return "buffer"
|
||||
if self.response:
|
||||
return "completed"
|
||||
return "pending"
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
def run(self):
|
||||
"Override this to implement the task's behavior"
|
||||
pass
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Union
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -17,6 +17,8 @@ class GenerateImageRequest(BaseModel):
|
||||
|
||||
init_image: Any = None
|
||||
init_image_mask: Any = None
|
||||
control_image: Any = None
|
||||
control_alpha: Union[float, List[float]] = None
|
||||
prompt_strength: float = 0.8
|
||||
preserve_init_image_color_profile = False
|
||||
|
||||
@ -26,6 +28,35 @@ class GenerateImageRequest(BaseModel):
|
||||
tiling: str = "none" # "none", "x", "y", "xy"
|
||||
|
||||
|
||||
class FilterImageRequest(BaseModel):
|
||||
image: Any = None
|
||||
filter: Union[str, List[str]] = None
|
||||
filter_params: dict = {}
|
||||
|
||||
|
||||
class ModelsData(BaseModel):
|
||||
"""
|
||||
Contains the information related to the models involved in a request.
|
||||
|
||||
- To load a model: set the relative path(s) to the model in `model_paths`. No effect if already loaded.
|
||||
- To unload a model: set the model to `None` in `model_paths`. No effect if already unloaded.
|
||||
|
||||
Models that aren't present in `model_paths` will not be changed.
|
||||
"""
|
||||
|
||||
model_paths: Dict[str, Union[str, None, List[str]]] = None
|
||||
"model_type to string path, or list of string paths"
|
||||
|
||||
model_params: Dict[str, Dict[str, Any]] = {}
|
||||
"model_type to dict of parameters"
|
||||
|
||||
|
||||
class OutputFormatData(BaseModel):
|
||||
output_format: str = "jpeg" # or "png" or "webp"
|
||||
output_quality: int = 75
|
||||
output_lossless: bool = False
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
request_id: str = None
|
||||
session_id: str = "session"
|
||||
@ -40,12 +71,12 @@ class TaskData(BaseModel):
|
||||
use_vae_model: Union[str, List[str]] = None
|
||||
use_hypernetwork_model: Union[str, List[str]] = None
|
||||
use_lora_model: Union[str, List[str]] = None
|
||||
use_controlnet_model: Union[str, List[str]] = None
|
||||
filters: List[str] = []
|
||||
filter_params: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
show_only_filtered_image: bool = False
|
||||
block_nsfw: bool = False
|
||||
output_format: str = "jpeg" # or "png" or "webp"
|
||||
output_quality: int = 75
|
||||
output_lossless: bool = False
|
||||
metadata_output_format: str = "txt" # or "json"
|
||||
stream_image_progress: bool = False
|
||||
stream_image_progress_interval: int = 5
|
||||
@ -80,24 +111,38 @@ class Image:
|
||||
}
|
||||
|
||||
|
||||
class Response:
|
||||
class GenerateImageResponse:
|
||||
render_request: GenerateImageRequest
|
||||
task_data: TaskData
|
||||
models_data: ModelsData
|
||||
images: list
|
||||
|
||||
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
|
||||
def __init__(
|
||||
self,
|
||||
render_request: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
output_format: OutputFormatData,
|
||||
images: list,
|
||||
):
|
||||
self.render_request = render_request
|
||||
self.task_data = task_data
|
||||
self.models_data = models_data
|
||||
self.output_format = output_format
|
||||
self.images = images
|
||||
|
||||
def json(self):
|
||||
del self.render_request.init_image
|
||||
del self.render_request.init_image_mask
|
||||
|
||||
task_data = self.task_data.dict()
|
||||
task_data.update(self.output_format.dict())
|
||||
|
||||
res = {
|
||||
"status": "succeeded",
|
||||
"render_request": self.render_request.dict(),
|
||||
"task_data": self.task_data.dict(),
|
||||
"task_data": task_data,
|
||||
# "models_data": self.models_data.dict(), # haven't migrated the UI to the new format (yet)
|
||||
"output": [],
|
||||
}
|
||||
|
||||
@ -107,5 +152,102 @@ class Response:
|
||||
return res
|
||||
|
||||
|
||||
class FilterImageResponse:
|
||||
request: FilterImageRequest
|
||||
models_data: ModelsData
|
||||
images: list
|
||||
|
||||
def __init__(self, request: FilterImageRequest, models_data: ModelsData, images: list):
|
||||
self.request = request
|
||||
self.models_data = models_data
|
||||
self.images = images
|
||||
|
||||
def json(self):
|
||||
del self.request.image
|
||||
|
||||
res = {
|
||||
"status": "succeeded",
|
||||
"request": self.request.dict(),
|
||||
"models_data": self.models_data.dict(),
|
||||
"output": [],
|
||||
}
|
||||
|
||||
for image in self.images:
|
||||
res["output"].append(image)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def convert_legacy_render_req_to_new(old_req: dict):
|
||||
new_req = dict(old_req)
|
||||
|
||||
# new keys
|
||||
model_paths = new_req["model_paths"] = {}
|
||||
model_params = new_req["model_params"] = {}
|
||||
filters = new_req["filters"] = []
|
||||
filter_params = new_req["filter_params"] = {}
|
||||
|
||||
# move the model info
|
||||
model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
|
||||
model_paths["vae"] = old_req.get("use_vae_model")
|
||||
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
|
||||
model_paths["lora"] = old_req.get("use_lora_model")
|
||||
model_paths["controlnet"] = old_req.get("use_controlnet_model")
|
||||
|
||||
model_paths["gfpgan"] = old_req.get("use_face_correction", "")
|
||||
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None
|
||||
|
||||
model_paths["codeformer"] = old_req.get("use_face_correction", "")
|
||||
model_paths["codeformer"] = model_paths["codeformer"] if "codeformer" in model_paths["codeformer"].lower() else None
|
||||
|
||||
model_paths["realesrgan"] = old_req.get("use_upscale", "")
|
||||
model_paths["realesrgan"] = model_paths["realesrgan"] if "realesrgan" in model_paths["realesrgan"].lower() else None
|
||||
|
||||
model_paths["latent_upscaler"] = old_req.get("use_upscale", "")
|
||||
model_paths["latent_upscaler"] = (
|
||||
model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
|
||||
)
|
||||
|
||||
if old_req.get("block_nsfw"):
|
||||
model_paths["nsfw_checker"] = "nsfw_checker"
|
||||
|
||||
# move the model params
|
||||
if model_paths["stable-diffusion"]:
|
||||
model_params["stable-diffusion"] = {"clip_skip": bool(old_req["clip_skip"])}
|
||||
|
||||
# move the filter params
|
||||
if model_paths["realesrgan"]:
|
||||
filter_params["realesrgan"] = {"scale": int(old_req["upscale_amount"])}
|
||||
if model_paths["latent_upscaler"]:
|
||||
filter_params["latent_upscaler"] = {
|
||||
"prompt": old_req["prompt"],
|
||||
"negative_prompt": old_req.get("negative_prompt"),
|
||||
"seed": int(old_req.get("seed", 42)),
|
||||
"num_inference_steps": int(old_req.get("latent_upscaler_steps", 10)),
|
||||
"guidance_scale": 0,
|
||||
}
|
||||
if model_paths["codeformer"]:
|
||||
filter_params["codeformer"] = {
|
||||
"upscale_faces": bool(old_req["codeformer_upscale_faces"]),
|
||||
"codeformer_fidelity": float(old_req["codeformer_fidelity"]),
|
||||
}
|
||||
|
||||
# set the filters
|
||||
if old_req.get("block_nsfw"):
|
||||
filters.append("nsfw_checker")
|
||||
|
||||
if model_paths["codeformer"]:
|
||||
filters.append("codeformer")
|
||||
elif model_paths["gfpgan"]:
|
||||
filters.append("gfpgan")
|
||||
|
||||
if model_paths["realesrgan"]:
|
||||
filters.append("realesrgan")
|
||||
elif model_paths["latent_upscaler"]:
|
||||
filters.append("latent_upscaler")
|
||||
|
||||
return new_req
|
||||
|
@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from functools import reduce
|
||||
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
|
||||
from numpy import base_repr
|
||||
from sdkit.utils import save_dicts, save_images
|
||||
|
||||
@ -114,12 +114,14 @@ def format_file_name(
|
||||
return filename_regex.sub("_", format)
|
||||
|
||||
|
||||
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||
def save_images_to_disk(
|
||||
images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData
|
||||
):
|
||||
now = time.time()
|
||||
app_config = app.getConfig()
|
||||
folder_format = app_config.get("folder_format", "$id")
|
||||
save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
|
||||
metadata_entries = get_metadata_entries_for_request(req, task_data)
|
||||
metadata_entries = get_metadata_entries_for_request(req, task_data, output_format)
|
||||
file_number = calculate_img_number(save_dir_path, task_data)
|
||||
make_filename = make_filename_callback(
|
||||
app_config.get("filename_format", "$p_$tsb64"),
|
||||
@ -134,9 +136,9 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
filtered_images,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
output_lossless=task_data.output_lossless,
|
||||
output_format=output_format.output_format,
|
||||
output_quality=output_format.output_quality,
|
||||
output_lossless=output_format.output_lossless,
|
||||
)
|
||||
if task_data.metadata_output_format:
|
||||
for metadata_output_format in task_data.metadata_output_format.split(","):
|
||||
@ -146,7 +148,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=metadata_output_format,
|
||||
file_format=task_data.output_format,
|
||||
file_format=output_format.output_format,
|
||||
)
|
||||
else:
|
||||
make_filter_filename = make_filename_callback(
|
||||
@ -162,17 +164,17 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
images,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
output_lossless=task_data.output_lossless,
|
||||
output_format=output_format.output_format,
|
||||
output_quality=output_format.output_quality,
|
||||
output_lossless=output_format.output_lossless,
|
||||
)
|
||||
save_images(
|
||||
filtered_images,
|
||||
save_dir_path,
|
||||
file_name=make_filter_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
output_lossless=task_data.output_lossless,
|
||||
output_format=output_format.output_format,
|
||||
output_quality=output_format.output_quality,
|
||||
output_lossless=output_format.output_lossless,
|
||||
)
|
||||
if task_data.metadata_output_format:
|
||||
for metadata_output_format in task_data.metadata_output_format.split(","):
|
||||
@ -181,20 +183,21 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
metadata_entries,
|
||||
save_dir_path,
|
||||
file_name=make_filter_filename,
|
||||
output_format=task_data.metadata_output_format,
|
||||
file_format=task_data.output_format,
|
||||
output_format=metadata_output_format,
|
||||
file_format=output_format.output_format,
|
||||
)
|
||||
|
||||
|
||||
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
metadata = get_printable_request(req, task_data)
|
||||
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
|
||||
metadata = get_printable_request(req, task_data, output_format)
|
||||
|
||||
# if text, format it in the text format expected by the UI
|
||||
is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",")
|
||||
if is_txt_format:
|
||||
|
||||
def format_value(value):
|
||||
if isinstance(value, list):
|
||||
return ", ".join([ str(it) for it in value ])
|
||||
return ", ".join([str(it) for it in value])
|
||||
return value
|
||||
|
||||
metadata = {
|
||||
@ -208,9 +211,10 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD
|
||||
return entries
|
||||
|
||||
|
||||
def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData):
|
||||
req_metadata = req.dict()
|
||||
task_data_metadata = task_data.dict()
|
||||
task_data_metadata.update(output_format.dict())
|
||||
|
||||
app_config = app.getConfig()
|
||||
using_diffusers = app_config.get("test_diffusers", False)
|
||||
@ -224,6 +228,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
metadata[key] = task_data_metadata[key]
|
||||
elif key == "use_embedding_models" and using_diffusers:
|
||||
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
||||
|
||||
def scan_directory(directory_path: str):
|
||||
used_embeddings = []
|
||||
for entry in os.scandir(directory_path):
|
||||
@ -232,15 +237,18 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
if entry_extension not in embeddings_extensions:
|
||||
continue
|
||||
|
||||
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])")
|
||||
embedding_name_regex = regex.compile(
|
||||
r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
|
||||
)
|
||||
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
||||
used_embeddings.append(entry.path)
|
||||
elif entry.is_dir():
|
||||
used_embeddings.extend(scan_directory(entry.path))
|
||||
return used_embeddings
|
||||
|
||||
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
||||
metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
|
||||
|
||||
|
||||
# Clean up the metadata
|
||||
if req.init_image is None and "prompt_strength" in metadata:
|
||||
del metadata["prompt_strength"]
|
||||
@ -254,7 +262,9 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
del metadata["latent_upscaler_steps"]
|
||||
|
||||
if not using_diffusers:
|
||||
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata):
|
||||
for key in (
|
||||
x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata
|
||||
):
|
||||
del metadata[key]
|
||||
|
||||
return metadata
|
||||
|
@ -1047,17 +1047,22 @@
|
||||
}
|
||||
}
|
||||
class FilterTask extends Task {
|
||||
constructor(options = {}) {}
|
||||
constructor(options = {}) {
|
||||
super(options)
|
||||
}
|
||||
/** Send current task to server.
|
||||
* @param {*} [timeout=-1] Optional timeout value in ms
|
||||
* @returns the response from the render request.
|
||||
* @memberof Task
|
||||
*/
|
||||
async post(timeout = -1) {
|
||||
let jsonResponse = await super.post("/filter", timeout)
|
||||
let res = await super.post("/filter", timeout)
|
||||
//this._setId(jsonResponse.task)
|
||||
this._setStatus(TaskStatus.waiting)
|
||||
|
||||
return res
|
||||
}
|
||||
checkReqBody() {}
|
||||
enqueue(progressCallback) {
|
||||
return Task.enqueueNew(this, FilterTask, progressCallback)
|
||||
}
|
||||
@ -1068,6 +1073,20 @@
|
||||
if (this.isStopped) {
|
||||
return
|
||||
}
|
||||
|
||||
this._setStatus(TaskStatus.pending)
|
||||
progressCallback?.call(this, { reqBody: this._reqBody })
|
||||
Object.freeze(this._reqBody)
|
||||
|
||||
// Post task request to backend
|
||||
let renderRes = undefined
|
||||
try {
|
||||
renderRes = yield this.post()
|
||||
yield progressCallback?.call(this, { renderResponse: renderRes })
|
||||
} catch (e) {
|
||||
yield progressCallback?.call(this, { detail: e.message })
|
||||
throw e
|
||||
}
|
||||
}
|
||||
static start(task, progressCallback) {
|
||||
if (typeof task !== "object") {
|
||||
|
Loading…
Reference in New Issue
Block a user