mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-09 06:59:30 +01:00
Merge pull request #1398 from ogmaresca/save-embeddings-to-metadata
Add embeddings to metadata
This commit is contained in:
commit
3980625be6
@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
import regex
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
@ -30,11 +32,12 @@ TASK_TEXT_MAPPING = {
|
|||||||
"lora_alpha": "LoRA Strength",
|
"lora_alpha": "LoRA Strength",
|
||||||
"use_hypernetwork_model": "Hypernetwork model",
|
"use_hypernetwork_model": "Hypernetwork model",
|
||||||
"hypernetwork_strength": "Hypernetwork Strength",
|
"hypernetwork_strength": "Hypernetwork Strength",
|
||||||
|
"use_embedding_models": "Embedding models",
|
||||||
"tiling": "Seamless Tiling",
|
"tiling": "Seamless Tiling",
|
||||||
"use_face_correction": "Use Face Correction",
|
"use_face_correction": "Use Face Correction",
|
||||||
"use_upscale": "Use Upscaling",
|
"use_upscale": "Use Upscaling",
|
||||||
"upscale_amount": "Upscale By",
|
"upscale_amount": "Upscale By",
|
||||||
"latent_upscaler_steps": "Latent Upscaler Steps"
|
"latent_upscaler_steps": "Latent Upscaler Steps",
|
||||||
}
|
}
|
||||||
|
|
||||||
time_placeholders = {
|
time_placeholders = {
|
||||||
@ -202,6 +205,9 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
|||||||
req_metadata = req.dict()
|
req_metadata = req.dict()
|
||||||
task_data_metadata = task_data.dict()
|
task_data_metadata = task_data.dict()
|
||||||
|
|
||||||
|
app_config = app.getConfig()
|
||||||
|
using_diffusers = app_config.get("test_diffusers", False)
|
||||||
|
|
||||||
# Save the metadata in the order defined in TASK_TEXT_MAPPING
|
# Save the metadata in the order defined in TASK_TEXT_MAPPING
|
||||||
metadata = {}
|
metadata = {}
|
||||||
for key in TASK_TEXT_MAPPING.keys():
|
for key in TASK_TEXT_MAPPING.keys():
|
||||||
@ -209,6 +215,24 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
|||||||
metadata[key] = req_metadata[key]
|
metadata[key] = req_metadata[key]
|
||||||
elif key in task_data_metadata:
|
elif key in task_data_metadata:
|
||||||
metadata[key] = task_data_metadata[key]
|
metadata[key] = task_data_metadata[key]
|
||||||
|
elif key is "use_embedding_models" and using_diffusers:
|
||||||
|
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
||||||
|
def scan_directory(directory_path: str):
|
||||||
|
used_embeddings = []
|
||||||
|
for entry in os.scandir(directory_path):
|
||||||
|
if entry.is_file():
|
||||||
|
entry_extension = os.path.splitext(entry.name)[1]
|
||||||
|
if entry_extension not in embeddings_extensions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])")
|
||||||
|
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
||||||
|
used_embeddings.append(entry.path)
|
||||||
|
elif entry.is_dir():
|
||||||
|
used_embeddings.extend(scan_directory(entry.path))
|
||||||
|
return used_embeddings
|
||||||
|
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
||||||
|
metadata["use_embedding_models"] = ", ".join(used_embeddings) if len(used_embeddings) > 0 else None
|
||||||
|
|
||||||
# Clean up the metadata
|
# Clean up the metadata
|
||||||
if req.init_image is None and "prompt_strength" in metadata:
|
if req.init_image is None and "prompt_strength" in metadata:
|
||||||
@ -222,8 +246,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
|
|||||||
if task_data.use_upscale != "latent_upscaler" and "latent_upscaler_steps" in metadata:
|
if task_data.use_upscale != "latent_upscaler" and "latent_upscaler_steps" in metadata:
|
||||||
del metadata["latent_upscaler_steps"]
|
del metadata["latent_upscaler_steps"]
|
||||||
|
|
||||||
app_config = app.getConfig()
|
if not using_diffusers:
|
||||||
if not app_config.get("test_diffusers", False):
|
|
||||||
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata):
|
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps"] if x in metadata):
|
||||||
del metadata[key]
|
del metadata[key]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user