Mirror of https://github.com/easydiffusion/easydiffusion.git, synced 2024-11-22 08:13:22 +01:00
Move the extraction logic for embeddings-from-prompt from sdkit to ED's UI
This commit is contained in:
parent
b4cc21ea89
commit
ac1c65fba1
@ -65,9 +65,6 @@ def load_default_models(context: Context):
|
||||
|
||||
runtime.set_vram_optimizations(context)
|
||||
|
||||
config = app.getConfig()
|
||||
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
|
||||
|
||||
# init default model paths
|
||||
for model_type in MODELS_TO_LOAD_ON_START:
|
||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False)
|
||||
|
@ -73,6 +73,7 @@ class TaskData(BaseModel):
|
||||
use_hypernetwork_model: Union[str, List[str]] = None
|
||||
use_lora_model: Union[str, List[str]] = None
|
||||
use_controlnet_model: Union[str, List[str]] = None
|
||||
use_embeddings_model: Union[str, List[str]] = None
|
||||
filters: List[str] = []
|
||||
filter_params: Dict[str, Dict[str, Any]] = {}
|
||||
control_filter_to_apply: Union[str, List[str]] = None
|
||||
@ -200,6 +201,7 @@ def convert_legacy_render_req_to_new(old_req: dict):
|
||||
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
|
||||
model_paths["lora"] = old_req.get("use_lora_model")
|
||||
model_paths["controlnet"] = old_req.get("use_controlnet_model")
|
||||
model_paths["embeddings"] = old_req.get("use_embeddings_model")
|
||||
|
||||
model_paths["gfpgan"] = old_req.get("use_face_correction", "")
|
||||
model_paths["gfpgan"] = model_paths["gfpgan"] if "gfpgan" in model_paths["gfpgan"].lower() else None
|
||||
|
@ -845,6 +845,7 @@ function makeImage() {
|
||||
reqBody: Object.assign({ prompt: prompt }, taskTemplate.reqBody),
|
||||
})
|
||||
)
|
||||
newTaskRequests.forEach(setEmbeddings)
|
||||
newTaskRequests.forEach(createTask)
|
||||
|
||||
updateInitialText()
|
||||
@ -1493,6 +1494,42 @@ function getCurrentUserRequest() {
|
||||
return newTask
|
||||
}
|
||||
|
||||
function setEmbeddings(task) {
|
||||
let prompt = task.reqBody.prompt.toLowerCase()
|
||||
let negativePrompt = task.reqBody.negative_prompt.toLowerCase()
|
||||
let overallPrompt = (prompt + " " + negativePrompt).split(" ")
|
||||
|
||||
let embeddingsTree = modelsOptions["embeddings"]
|
||||
let embeddings = []
|
||||
function extract(entries, basePath = "") {
|
||||
entries.forEach((e) => {
|
||||
if (Array.isArray(e)) {
|
||||
let path = basePath === "" ? basePath + e[0] : basePath + "/" + e[0]
|
||||
extract(e[1], path)
|
||||
} else {
|
||||
let path = basePath === "" ? basePath + e : basePath + "/" + e
|
||||
embeddings.push([e.toLowerCase().replace(" ", "_"), path])
|
||||
}
|
||||
})
|
||||
}
|
||||
extract(embeddingsTree)
|
||||
|
||||
let embeddingPaths = []
|
||||
|
||||
embeddings.forEach((e) => {
|
||||
let token = e[0]
|
||||
let path = e[1]
|
||||
|
||||
if (overallPrompt.includes(token)) {
|
||||
embeddingPaths.push(path)
|
||||
}
|
||||
})
|
||||
|
||||
if (embeddingPaths.length > 0) {
|
||||
task.reqBody.use_embeddings_model = embeddingPaths
|
||||
}
|
||||
}
|
||||
|
||||
function getModelInfo(models) {
|
||||
let modelInfo = models.map((e) => [e[0].value, e[1].value])
|
||||
modelInfo = modelInfo.filter((e) => e[0].trim() !== "")
|
||||
|
Loading…
Reference in New Issue
Block a user