Fix incorrect metadata generation of embeddings, by removing duplicated logic. The UI already handles this

This commit is contained in:
cmdr2 2023-09-01 19:52:20 +05:30
parent b0294f8cbd
commit ad5641fa3e
2 changed files with 27 additions and 27 deletions

View File

@ -170,7 +170,7 @@ def print_task_info(
output_format: OutputFormatData, output_format: OutputFormatData,
save_data: SaveToDiskData, save_data: SaveToDiskData,
): ):
req_str = pprint.pformat(get_printable_request(req, task_data, output_format, save_data)).replace("[", "\[") req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[") task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
models_data = pprint.pformat(models_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
output_format = pprint.pformat(output_format.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
@ -212,7 +212,7 @@ def make_images_internal(
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
if save_data.save_to_disk_path is not None: if save_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data, output_format, save_data) save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data)
seeds = [*range(req.seed, req.seed + len(images))] seeds = [*range(req.seed, req.seed + len(images))]
if task_data.show_only_filtered_image or filtered_images is images: if task_data.show_only_filtered_image or filtered_images is images:

View File

@ -7,7 +7,14 @@ from datetime import datetime
from functools import reduce from functools import reduce
from easydiffusion import app from easydiffusion import app
from easydiffusion.types import GenerateImageRequest, TaskData, RenderTaskData, OutputFormatData, SaveToDiskData from easydiffusion.types import (
GenerateImageRequest,
TaskData,
RenderTaskData,
OutputFormatData,
SaveToDiskData,
ModelsData,
)
from numpy import base_repr from numpy import base_repr
from sdkit.utils import save_dicts, save_images from sdkit.utils import save_dicts, save_images
from sdkit.models.model_loader.embeddings import get_embedding_token from sdkit.models.model_loader.embeddings import get_embedding_token
@ -122,6 +129,7 @@ def save_images_to_disk(
filtered_images: list, filtered_images: list,
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: RenderTaskData, task_data: RenderTaskData,
models_data: ModelsData,
output_format: OutputFormatData, output_format: OutputFormatData,
save_data: SaveToDiskData, save_data: SaveToDiskData,
): ):
@ -129,7 +137,7 @@ def save_images_to_disk(
app_config = app.getConfig() app_config = app.getConfig()
folder_format = app_config.get("folder_format", "$id") folder_format = app_config.get("folder_format", "$id")
save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
metadata_entries = get_metadata_entries_for_request(req, task_data, output_format, save_data) metadata_entries = get_metadata_entries_for_request(req, task_data, models_data, output_format, save_data)
file_number = calculate_img_number(save_dir_path, task_data) file_number = calculate_img_number(save_dir_path, task_data)
make_filename = make_filename_callback( make_filename = make_filename_callback(
app_config.get("filename_format", "$p_$tsb64"), app_config.get("filename_format", "$p_$tsb64"),
@ -197,9 +205,13 @@ def save_images_to_disk(
def get_metadata_entries_for_request( def get_metadata_entries_for_request(
req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData req: GenerateImageRequest,
task_data: RenderTaskData,
models_data: ModelsData,
output_format: OutputFormatData,
save_data: SaveToDiskData,
): ):
metadata = get_printable_request(req, task_data, output_format, save_data) metadata = get_printable_request(req, task_data, models_data, output_format, save_data)
# if text, format it in the text format expected by the UI # if text, format it in the text format expected by the UI
is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",") is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",")
@ -222,7 +234,11 @@ def get_metadata_entries_for_request(
def get_printable_request( def get_printable_request(
req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData req: GenerateImageRequest,
task_data: RenderTaskData,
models_data: ModelsData,
output_format: OutputFormatData,
save_data: SaveToDiskData,
): ):
req_metadata = req.dict() req_metadata = req.dict()
task_data_metadata = task_data.dict() task_data_metadata = task_data.dict()
@ -240,27 +256,11 @@ def get_printable_request(
elif key in task_data_metadata: elif key in task_data_metadata:
metadata[key] = task_data_metadata[key] metadata[key] = task_data_metadata[key]
if key == "use_embeddings_model" and using_diffusers: if key == "use_embeddings_model" and task_data_metadata[key] and using_diffusers:
embeddings_extensions = {".pt", ".bin", ".safetensors"} embeddings_used = models_data.model_paths["embeddings"]
embeddings_used = embeddings_used if isinstance(embeddings_used, list) else [embeddings_used]
def scan_directory(directory_path: str): metadata["use_embeddings_model"] = embeddings_used if len(embeddings_used) > 0 else None
used_embeddings = []
for entry in os.scandir(directory_path):
if entry.is_file():
# Check if the filename has the right extension
if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)):
continue
embedding_name_regex = regex.compile(
r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])"
)
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
used_embeddings.append(entry.path)
elif entry.is_dir():
used_embeddings.extend(scan_directory(entry.path))
return used_embeddings
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None
# Clean up the metadata # Clean up the metadata
if req.init_image is None and "prompt_strength" in metadata: if req.init_image is None and "prompt_strength" in metadata: