diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index 528a33b1..9260349f 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -170,7 +170,7 @@ def print_task_info( output_format: OutputFormatData, save_data: SaveToDiskData, ): - req_str = pprint.pformat(get_printable_request(req, task_data, output_format, save_data)).replace("[", "\[") + req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[") task_str = pprint.pformat(task_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[") @@ -212,7 +212,7 @@ def make_images_internal( filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images if save_data.save_to_disk_path is not None: - save_images_to_disk(images, filtered_images, req, task_data, output_format, save_data) + save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data) seeds = [*range(req.seed, req.seed + len(images))] if task_data.show_only_filtered_image or filtered_images is images: diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 0694ca33..457af921 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -7,7 +7,14 @@ from datetime import datetime from functools import reduce from easydiffusion import app -from easydiffusion.types import GenerateImageRequest, TaskData, RenderTaskData, OutputFormatData, SaveToDiskData +from easydiffusion.types import ( + GenerateImageRequest, + TaskData, + RenderTaskData, + OutputFormatData, + SaveToDiskData, + ModelsData, +) from numpy import base_repr from sdkit.utils import save_dicts, save_images from sdkit.models.model_loader.embeddings import get_embedding_token @@ -122,6 +129,7 @@ def save_images_to_disk( filtered_images: list, req: GenerateImageRequest, task_data: RenderTaskData, + models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData, ): @@ -129,7 +137,7 @@ def save_images_to_disk( app_config = app.getConfig() folder_format = app_config.get("folder_format", "$id") save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) - metadata_entries = get_metadata_entries_for_request(req, task_data, output_format, save_data) + metadata_entries = get_metadata_entries_for_request(req, task_data, models_data, output_format, save_data) file_number = calculate_img_number(save_dir_path, task_data) make_filename = make_filename_callback( app_config.get("filename_format", "$p_$tsb64"), @@ -197,9 +205,13 @@ def save_images_to_disk( def get_metadata_entries_for_request( - req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData + req: GenerateImageRequest, + task_data: RenderTaskData, + models_data: ModelsData, + output_format: OutputFormatData, + save_data: SaveToDiskData, ): - metadata = get_printable_request(req, task_data, output_format, save_data) + metadata = get_printable_request(req, task_data, models_data, output_format, save_data) # if text, format it in the text format expected by the UI is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",") @@ -222,7 +234,11 @@ def get_metadata_entries_for_request( def get_printable_request( - req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData + req: GenerateImageRequest, + task_data: RenderTaskData, + models_data: ModelsData, + output_format: OutputFormatData, + save_data: SaveToDiskData, ): req_metadata = req.dict() task_data_metadata = task_data.dict() @@ -240,27 +256,11 @@ def get_printable_request( elif key in task_data_metadata: metadata[key] = task_data_metadata[key] - if key == "use_embeddings_model" and using_diffusers: - embeddings_extensions = {".pt", ".bin", ".safetensors"} + if key == "use_embeddings_model" and task_data_metadata[key] and using_diffusers: + embeddings_used = models_data.model_paths["embeddings"] + embeddings_used = embeddings_used if isinstance(embeddings_used, list) else [embeddings_used] - def scan_directory(directory_path: str): - used_embeddings = [] - for entry in os.scandir(directory_path): - if entry.is_file(): - # Check if the filename has the right extension - if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)): - continue - embedding_name_regex = regex.compile( - r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])" - ) - if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt): - used_embeddings.append(entry.path) - elif entry.is_dir(): - used_embeddings.extend(scan_directory(entry.path)) - return used_embeddings - - used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings")) - metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None + metadata["use_embeddings_model"] = embeddings_used if len(embeddings_used) > 0 else None # Clean up the metadata if req.init_image is None and "prompt_strength" in metadata: