mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-15 18:09:16 +01:00
Fix incorrect metadata generation of embeddings, by removing duplicated logic. The UI already handles this
This commit is contained in:
parent
b0294f8cbd
commit
ad5641fa3e
@ -170,7 +170,7 @@ def print_task_info(
|
|||||||
output_format: OutputFormatData,
|
output_format: OutputFormatData,
|
||||||
save_data: SaveToDiskData,
|
save_data: SaveToDiskData,
|
||||||
):
|
):
|
||||||
req_str = pprint.pformat(get_printable_request(req, task_data, output_format, save_data)).replace("[", "\[")
|
req_str = pprint.pformat(get_printable_request(req, task_data, models_data, output_format, save_data)).replace("[", "\[")
|
||||||
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
||||||
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
|
models_data = pprint.pformat(models_data.dict()).replace("[", "\[")
|
||||||
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
|
output_format = pprint.pformat(output_format.dict()).replace("[", "\[")
|
||||||
@ -212,7 +212,7 @@ def make_images_internal(
|
|||||||
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
|
filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images
|
||||||
|
|
||||||
if save_data.save_to_disk_path is not None:
|
if save_data.save_to_disk_path is not None:
|
||||||
save_images_to_disk(images, filtered_images, req, task_data, output_format, save_data)
|
save_images_to_disk(images, filtered_images, req, task_data, models_data, output_format, save_data)
|
||||||
|
|
||||||
seeds = [*range(req.seed, req.seed + len(images))]
|
seeds = [*range(req.seed, req.seed + len(images))]
|
||||||
if task_data.show_only_filtered_image or filtered_images is images:
|
if task_data.show_only_filtered_image or filtered_images is images:
|
||||||
|
@ -7,7 +7,14 @@ from datetime import datetime
|
|||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
from easydiffusion import app
|
from easydiffusion import app
|
||||||
from easydiffusion.types import GenerateImageRequest, TaskData, RenderTaskData, OutputFormatData, SaveToDiskData
|
from easydiffusion.types import (
|
||||||
|
GenerateImageRequest,
|
||||||
|
TaskData,
|
||||||
|
RenderTaskData,
|
||||||
|
OutputFormatData,
|
||||||
|
SaveToDiskData,
|
||||||
|
ModelsData,
|
||||||
|
)
|
||||||
from numpy import base_repr
|
from numpy import base_repr
|
||||||
from sdkit.utils import save_dicts, save_images
|
from sdkit.utils import save_dicts, save_images
|
||||||
from sdkit.models.model_loader.embeddings import get_embedding_token
|
from sdkit.models.model_loader.embeddings import get_embedding_token
|
||||||
@ -122,6 +129,7 @@ def save_images_to_disk(
|
|||||||
filtered_images: list,
|
filtered_images: list,
|
||||||
req: GenerateImageRequest,
|
req: GenerateImageRequest,
|
||||||
task_data: RenderTaskData,
|
task_data: RenderTaskData,
|
||||||
|
models_data: ModelsData,
|
||||||
output_format: OutputFormatData,
|
output_format: OutputFormatData,
|
||||||
save_data: SaveToDiskData,
|
save_data: SaveToDiskData,
|
||||||
):
|
):
|
||||||
@ -129,7 +137,7 @@ def save_images_to_disk(
|
|||||||
app_config = app.getConfig()
|
app_config = app.getConfig()
|
||||||
folder_format = app_config.get("folder_format", "$id")
|
folder_format = app_config.get("folder_format", "$id")
|
||||||
save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
|
save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
|
||||||
metadata_entries = get_metadata_entries_for_request(req, task_data, output_format, save_data)
|
metadata_entries = get_metadata_entries_for_request(req, task_data, models_data, output_format, save_data)
|
||||||
file_number = calculate_img_number(save_dir_path, task_data)
|
file_number = calculate_img_number(save_dir_path, task_data)
|
||||||
make_filename = make_filename_callback(
|
make_filename = make_filename_callback(
|
||||||
app_config.get("filename_format", "$p_$tsb64"),
|
app_config.get("filename_format", "$p_$tsb64"),
|
||||||
@ -197,9 +205,13 @@ def save_images_to_disk(
|
|||||||
|
|
||||||
|
|
||||||
def get_metadata_entries_for_request(
|
def get_metadata_entries_for_request(
|
||||||
req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData
|
req: GenerateImageRequest,
|
||||||
|
task_data: RenderTaskData,
|
||||||
|
models_data: ModelsData,
|
||||||
|
output_format: OutputFormatData,
|
||||||
|
save_data: SaveToDiskData,
|
||||||
):
|
):
|
||||||
metadata = get_printable_request(req, task_data, output_format, save_data)
|
metadata = get_printable_request(req, task_data, models_data, output_format, save_data)
|
||||||
|
|
||||||
# if text, format it in the text format expected by the UI
|
# if text, format it in the text format expected by the UI
|
||||||
is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",")
|
is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",")
|
||||||
@ -222,7 +234,11 @@ def get_metadata_entries_for_request(
|
|||||||
|
|
||||||
|
|
||||||
def get_printable_request(
|
def get_printable_request(
|
||||||
req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData
|
req: GenerateImageRequest,
|
||||||
|
task_data: RenderTaskData,
|
||||||
|
models_data: ModelsData,
|
||||||
|
output_format: OutputFormatData,
|
||||||
|
save_data: SaveToDiskData,
|
||||||
):
|
):
|
||||||
req_metadata = req.dict()
|
req_metadata = req.dict()
|
||||||
task_data_metadata = task_data.dict()
|
task_data_metadata = task_data.dict()
|
||||||
@ -240,27 +256,11 @@ def get_printable_request(
|
|||||||
elif key in task_data_metadata:
|
elif key in task_data_metadata:
|
||||||
metadata[key] = task_data_metadata[key]
|
metadata[key] = task_data_metadata[key]
|
||||||
|
|
||||||
if key == "use_embeddings_model" and using_diffusers:
|
if key == "use_embeddings_model" and task_data_metadata[key] and using_diffusers:
|
||||||
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
embeddings_used = models_data.model_paths["embeddings"]
|
||||||
|
embeddings_used = embeddings_used if isinstance(embeddings_used, list) else [embeddings_used]
|
||||||
|
|
||||||
def scan_directory(directory_path: str):
|
metadata["use_embeddings_model"] = embeddings_used if len(embeddings_used) > 0 else None
|
||||||
used_embeddings = []
|
|
||||||
for entry in os.scandir(directory_path):
|
|
||||||
if entry.is_file():
|
|
||||||
# Check if the filename has the right extension
|
|
||||||
if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)):
|
|
||||||
continue
|
|
||||||
embedding_name_regex = regex.compile(
|
|
||||||
r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])"
|
|
||||||
)
|
|
||||||
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
|
||||||
used_embeddings.append(entry.path)
|
|
||||||
elif entry.is_dir():
|
|
||||||
used_embeddings.extend(scan_directory(entry.path))
|
|
||||||
return used_embeddings
|
|
||||||
|
|
||||||
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
|
||||||
metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None
|
|
||||||
|
|
||||||
# Clean up the metadata
|
# Clean up the metadata
|
||||||
if req.init_image is None and "prompt_strength" in metadata:
|
if req.init_image is None and "prompt_strength" in metadata:
|
||||||
|
Loading…
Reference in New Issue
Block a user