From ac1c65fba11c3d79e64f42509c51372a76364285 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 17 Aug 2023 10:54:47 +0530 Subject: [PATCH] Move the extraction logic for embeddings-from-prompt, from sdkit to ED's UI --- ui/easydiffusion/model_manager.py | 3 --- ui/easydiffusion/types.py | 2 ++ ui/media/js/main.js | 37 +++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 845e9126..841b0cd2 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -65,9 +65,6 @@ def load_default_models(context: Context): runtime.set_vram_optimizations(context) - config = app.getConfig() - context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings") - # init default model paths for model_type in MODELS_TO_LOAD_ON_START: context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False) diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index fe936ca2..eeffcb72 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -73,6 +73,7 @@ class TaskData(BaseModel): use_hypernetwork_model: Union[str, List[str]] = None use_lora_model: Union[str, List[str]] = None use_controlnet_model: Union[str, List[str]] = None + use_embeddings_model: Union[str, List[str]] = None filters: List[str] = [] filter_params: Dict[str, Dict[str, Any]] = {} control_filter_to_apply: Union[str, List[str]] = None @@ -200,6 +201,7 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model") model_paths["lora"] = old_req.get("use_lora_model") model_paths["controlnet"] = old_req.get("use_controlnet_model") + model_paths["embeddings"] = old_req.get("use_embeddings_model") model_paths["gfpgan"] = old_req.get("use_face_correction", "") model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None diff --git a/ui/media/js/main.js b/ui/media/js/main.js index ea2af7e3..c29850c2 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -845,6 +845,7 @@ function makeImage() { reqBody: Object.assign({ prompt: prompt }, taskTemplate.reqBody), }) ) + newTaskRequests.forEach(setEmbeddings) newTaskRequests.forEach(createTask) updateInitialText() @@ -1493,6 +1494,42 @@ function getCurrentUserRequest() { return newTask } +function setEmbeddings(task) { + let prompt = task.reqBody.prompt.toLowerCase() + let negativePrompt = task.reqBody.negative_prompt.toLowerCase() + let overallPrompt = (prompt + " " + negativePrompt).split(" ") + + let embeddingsTree = modelsOptions["embeddings"] + let embeddings = [] + function extract(entries, basePath = "") { + entries.forEach((e) => { + if (Array.isArray(e)) { + let path = basePath === "" ? basePath + e[0] : basePath + "/" + e[0] + extract(e[1], path) + } else { + let path = basePath === "" ? basePath + e : basePath + "/" + e + embeddings.push([e.toLowerCase().replace(" ", "_"), path]) + } + }) + } + extract(embeddingsTree) + + let embeddingPaths = [] + + embeddings.forEach((e) => { + let token = e[0] + let path = e[1] + + if (overallPrompt.includes(token)) { + embeddingPaths.push(path) + } + }) + + if (embeddingPaths.length > 0) { + task.reqBody.use_embeddings_model = embeddingPaths + } +} + function getModelInfo(models) { let modelInfo = models.map((e) => [e[0].value, e[1].value]) modelInfo = modelInfo.filter((e) => e[0].trim() !== "")