Fix incorrect metadata generation of embeddings, by removing duplicated logic. The UI already handles this

This commit is contained in:
cmdr2 2023-09-01 19:52:20 +05:30
parent b0294f8cbd
commit ad5641fa3e
2 changed files with 27 additions and 27 deletions

View File

@ -170,7 +170,7 @@ def print_task_info(
output_format: OutputFormatData,
save_data: SaveToDiskData,
):
req_str = pprint.pformat(get_printable_request(req, task_data, output_format, save_data)).replace("[", "\[")
req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
@ -212,7 +212,7 @@ def make_images_internal(
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
if save_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data, output_format, save_data)
save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data)
seeds = [*range(req.seed, req.seed + len(images))]
if task_data.show_only_filtered_image or filtered_images is images:

View File

@ -7,7 +7,14 @@ from datetime import datetime
from functools import reduce
from easydiffusion import app
from easydiffusion.types import GenerateImageRequest, TaskData, RenderTaskData, OutputFormatData, SaveToDiskData
from easydiffusion.types import (
GenerateImageRequest,
TaskData,
RenderTaskData,
OutputFormatData,
SaveToDiskData,
ModelsData,
)
from numpy import base_repr
from sdkit.utils import save_dicts, save_images
from sdkit.models.model_loader.embeddings import get_embedding_token
@ -122,6 +129,7 @@ def save_images_to_disk(
filtered_images: list,
req: GenerateImageRequest,
task_data: RenderTaskData,
models_data: ModelsData,
output_format: OutputFormatData,
save_data: SaveToDiskData,
):
@ -129,7 +137,7 @@ def save_images_to_disk(
app_config = app.getConfig()
folder_format = app_config.get("folder_format", "$id")
save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
metadata_entries = get_metadata_entries_for_request(req, task_data, output_format, save_data)
metadata_entries = get_metadata_entries_for_request(req, task_data, models_data, output_format, save_data)
file_number = calculate_img_number(save_dir_path, task_data)
make_filename = make_filename_callback(
app_config.get("filename_format", "$p_$tsb64"),
@ -197,9 +205,13 @@ def save_images_to_disk(
def get_metadata_entries_for_request(
req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData
req: GenerateImageRequest,
task_data: RenderTaskData,
models_data: ModelsData,
output_format: OutputFormatData,
save_data: SaveToDiskData,
):
metadata = get_printable_request(req, task_data, output_format, save_data)
metadata = get_printable_request(req, task_data, models_data, output_format, save_data)
# if text, format it in the text format expected by the UI
is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",")
@ -222,7 +234,11 @@ def get_metadata_entries_for_request(
def get_printable_request(
req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData
req: GenerateImageRequest,
task_data: RenderTaskData,
models_data: ModelsData,
output_format: OutputFormatData,
save_data: SaveToDiskData,
):
req_metadata = req.dict()
task_data_metadata = task_data.dict()
@ -240,27 +256,11 @@ def get_printable_request(
elif key in task_data_metadata:
metadata[key] = task_data_metadata[key]
if key == "use_embeddings_model" and using_diffusers:
embeddings_extensions = {".pt", ".bin", ".safetensors"}
if key == "use_embeddings_model" and task_data_metadata[key] and using_diffusers:
embeddings_used = models_data.model_paths["embeddings"]
embeddings_used = embeddings_used if isinstance(embeddings_used, list) else [embeddings_used]
def scan_directory(directory_path: str):
used_embeddings = []
for entry in os.scandir(directory_path):
if entry.is_file():
# Check if the filename has the right extension
if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)):
continue
embedding_name_regex = regex.compile(
r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])"
)
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
used_embeddings.append(entry.path)
elif entry.is_dir():
used_embeddings.extend(scan_directory(entry.path))
return used_embeddings
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None
metadata["use_embeddings_model"] = embeddings_used if len(embeddings_used) > 0 else None
# Clean up the metadata
if req.init_image is None and "prompt_strength" in metadata: