mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 00:03:20 +01:00
fix embedding parser and use standard embedding varuable for metadata (#1516)
This commit is contained in:
parent
18049d529a
commit
669d40a9d2
@ -34,7 +34,7 @@ TASK_TEXT_MAPPING = {
|
||||
"lora_alpha": "LoRA Strength",
|
||||
"use_hypernetwork_model": "Hypernetwork model",
|
||||
"hypernetwork_strength": "Hypernetwork Strength",
|
||||
"use_embedding_models": "Embedding models",
|
||||
"use_embeddings_model": "Embedding models",
|
||||
"tiling": "Seamless Tiling",
|
||||
"use_face_correction": "Use Face Correction",
|
||||
"use_upscale": "Use Upscaling",
|
||||
@ -228,28 +228,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output
|
||||
metadata[key] = req_metadata[key]
|
||||
elif key in task_data_metadata:
|
||||
metadata[key] = task_data_metadata[key]
|
||||
elif key == "use_embedding_models" and using_diffusers:
|
||||
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
||||
|
||||
def scan_directory(directory_path: str):
|
||||
used_embeddings = []
|
||||
for entry in os.scandir(directory_path):
|
||||
if entry.is_file():
|
||||
entry_extension = os.path.splitext(entry.name)[1]
|
||||
if entry_extension not in embeddings_extensions:
|
||||
continue
|
||||
|
||||
embedding_name_regex = regex.compile(
|
||||
r"(^|[\s,])" + regex.escape(os.path.splitext(entry.name)[0]) + r"([+-]*$|[\s,]|[+-]+[\s,])"
|
||||
)
|
||||
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
||||
used_embeddings.append(entry.path)
|
||||
elif entry.is_dir():
|
||||
used_embeddings.extend(scan_directory(entry.path))
|
||||
return used_embeddings
|
||||
|
||||
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
||||
metadata["use_embedding_models"] = used_embeddings if len(used_embeddings) > 0 else None
|
||||
|
||||
|
||||
# Clean up the metadata
|
||||
if req.init_image is None and "prompt_strength" in metadata:
|
||||
|
@ -1306,7 +1306,7 @@ function getCurrentUserRequest() {
|
||||
function setEmbeddings(task) {
|
||||
let prompt = task.reqBody.prompt.toLowerCase()
|
||||
let negativePrompt = task.reqBody.negative_prompt.toLowerCase()
|
||||
let overallPrompt = (prompt + " " + negativePrompt).replaceAll(",", "").split(" ")
|
||||
let overallPrompt = (prompt + " " + negativePrompt)
|
||||
|
||||
let embeddingsTree = modelsOptions["embeddings"]
|
||||
let embeddings = []
|
||||
|
Loading…
Reference in New Issue
Block a user