mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 08:13:22 +01:00
fix embedding parser and use standard embedding varuable for metadata (#1516)
This commit is contained in:
parent
18049d529a
commit
669d40a9d2
@ -34,7 +34,7 @@ TASK_TEXT_MAPPING = {
|
|||||||
"lora_alpha": "LoRA Strength",
|
"lora_alpha": "LoRA Strength",
|
||||||
"use_hypernetwork_model": "Hypernetwork model",
|
"use_hypernetwork_model": "Hypernetwork model",
|
||||||
"hypernetwork_strength": "Hypernetwork Strength",
|
"hypernetwork_strength": "Hypernetwork Strength",
|
||||||
"use_embedding_models": "Embedding models",
|
"use_embeddings_model": "Embedding models",
|
||||||
"tiling": "Seamless Tiling",
|
"tiling": "Seamless Tiling",
|
||||||
"use_face_correction": "Use Face Correction",
|
"use_face_correction": "Use Face Correction",
|
||||||
"use_upscale": "Use Upscaling",
|
"use_upscale": "Use Upscaling",
|
||||||
@ -228,28 +228,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output
|
|||||||
metadata[key] = req_metadata[key]
|
metadata[key] = req_metadata[key]
|
||||||
elif key in task_data_metadata:
|
elif key in task_data_metadata:
|
||||||
metadata[key] = task_data_metadata[key]
|
metadata[key] = task_data_metadata[key]
|
||||||
elif key == "use_embedding_models" and using_diffusers:
|
|
||||||
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
|
||||||
|
|
||||||
def scan_directory(directory_path: str):
|
|
||||||
used_embeddings = []
|
|
||||||
for entry in os.scandir(directory_path):
|
|
||||||
if entry.is_file():
|
|
||||||
entry_extension = os.path.splitext(entry.name)[1]
|
|
||||||
if entry_extension not in embeddings_extensions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
embedding_name_regex = regex.compile(
|
|
||||||
r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
|
|
||||||
)
|
|
||||||
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
|
||||||
used_embeddings.append(entry.path)
|
|
||||||
elif entry.is_dir():
|
|
||||||
used_embeddings.extend(scan_directory(entry.path))
|
|
||||||
return used_embeddings
|
|
||||||
|
|
||||||
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
|
||||||
metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
|
|
||||||
|
|
||||||
# Clean up the metadata
|
# Clean up the metadata
|
||||||
if req.init_image is None and "prompt_strength" in metadata:
|
if req.init_image is None and "prompt_strength" in metadata:
|
||||||
|
@ -1306,7 +1306,7 @@ function getCurrentUserRequest() {
|
|||||||
function setEmbeddings(task) {
|
function setEmbeddings(task) {
|
||||||
let prompt = task.reqBody.prompt.toLowerCase()
|
let prompt = task.reqBody.prompt.toLowerCase()
|
||||||
let negativePrompt = task.reqBody.negative_prompt.toLowerCase()
|
let negativePrompt = task.reqBody.negative_prompt.toLowerCase()
|
||||||
let overallPrompt = (prompt + " " + negativePrompt).replaceAll(",", "").split(" ")
|
let overallPrompt = (prompt + " " + negativePrompt)
|
||||||
|
|
||||||
let embeddingsTree = modelsOptions["embeddings"]
|
let embeddingsTree = modelsOptions["embeddings"]
|
||||||
let embeddings = []
|
let embeddings = []
|
||||||
|
Loading…
Reference in New Issue
Block a user