Move the extraction logic for embeddings-from-prompt, from sdkit to ED's UI

This commit is contained in:
cmdr2 2023-08-17 10:54:47 +05:30
parent b4cc21ea89
commit ac1c65fba1
3 changed files with 39 additions and 3 deletions

View File

@ -65,9 +65,6 @@ def load_default_models(context: Context):
runtime.set_vram_optimizations(context) runtime.set_vram_optimizations(context)
config = app.getConfig()
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
# init default model paths # init default model paths
for model_type in MODELS_TO_LOAD_ON_START: for model_type in MODELS_TO_LOAD_ON_START:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False) context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False)

View File

@ -73,6 +73,7 @@ class TaskData(BaseModel):
use_hypernetwork_model: Union[str, List[str]] = None use_hypernetwork_model: Union[str, List[str]] = None
use_lora_model: Union[str, List[str]] = None use_lora_model: Union[str, List[str]] = None
use_controlnet_model: Union[str, List[str]] = None use_controlnet_model: Union[str, List[str]] = None
use_embeddings_model: Union[str, List[str]] = None
filters: List[str] = [] filters: List[str] = []
filter_params: Dict[str, Dict[str, Any]] = {} filter_params: Dict[str, Dict[str, Any]] = {}
control_filter_to_apply: Union[str, List[str]] = None control_filter_to_apply: Union[str, List[str]] = None
@ -200,6 +201,7 @@ def convert_legacy_render_req_to_new(old_req: dict):
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model") model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
model_paths["lora"] = old_req.get("use_lora_model") model_paths["lora"] = old_req.get("use_lora_model")
model_paths["controlnet"] = old_req.get("use_controlnet_model") model_paths["controlnet"] = old_req.get("use_controlnet_model")
model_paths["embeddings"] = old_req.get("use_embeddings_model")
model_paths["gfpgan"] = old_req.get("use_face_correction", "") model_paths["gfpgan"] = old_req.get("use_face_correction", "")
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None

View File

@ -845,6 +845,7 @@ function makeImage() {
reqBody: Object.assign({ prompt: prompt }, taskTemplate.reqBody), reqBody: Object.assign({ prompt: prompt }, taskTemplate.reqBody),
}) })
) )
newTaskRequests.forEach(setEmbeddings)
newTaskRequests.forEach(createTask) newTaskRequests.forEach(createTask)
updateInitialText() updateInitialText()
@ -1493,6 +1494,42 @@ function getCurrentUserRequest() {
return newTask return newTask
} }
function setEmbeddings(task) {
let prompt = task.reqBody.prompt.toLowerCase()
let negativePrompt = task.reqBody.negative_prompt.toLowerCase()
let overallPrompt = (prompt + " " + negativePrompt).split(" ")
let embeddingsTree = modelsOptions["embeddings"]
let embeddings = []
function extract(entries, basePath = "") {
entries.forEach((e) => {
if (Array.isArray(e)) {
let path = basePath === "" ? basePath + e[0] : basePath + "/" + e[0]
extract(e[1], path)
} else {
let path = basePath === "" ? basePath + e : basePath + "/" + e
embeddings.push([e.toLowerCase().replace(" ", "_"), path])
}
})
}
extract(embeddingsTree)
let embeddingPaths = []
embeddings.forEach((e) => {
let token = e[0]
let path = e[1]
if (overallPrompt.includes(token)) {
embeddingPaths.push(path)
}
})
if (embeddingPaths.length > 0) {
task.reqBody.use_embeddings_model = embeddingPaths
}
}
function getModelInfo(models) { function getModelInfo(models) {
let modelInfo = models.map((e) => [e[0].value, e[1].value]) let modelInfo = models.map((e) => [e[0].value, e[1].value])
modelInfo = modelInfo.filter((e) => e[0].trim() !== "") modelInfo = modelInfo.filter((e) => e[0].trim() !== "")