From 67cae9725e066044b5adb7261765687cfc14566c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 18 Aug 2023 14:16:23 +0530 Subject: [PATCH] Fix drag-and-drop and Use these Settings for LoRA --- ui/media/js/dnd.js | 78 ++++++++++++----------------- ui/media/js/multi-model-selector.js | 45 +++++++++++++++++ 2 files changed, 76 insertions(+), 47 deletions(-) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 3cca985d..7baa27b3 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -292,39 +292,36 @@ const TASK_MAPPING = { use_lora_model: { name: "LoRA model", setUI: (use_lora_model) => { - // create rows - for (let i = loraModels.length; i < use_lora_model.length; i++) { - createLoraEntry() - } - - use_lora_model.forEach((model_name, i) => { - let field = loraModels[i][0] - const oldVal = field.value - - if (model_name !== "") { - model_name = getModelPath(model_name, [".ckpt", ".safetensors"]) - model_name = model_name !== "" ? model_name : oldVal + let modelPaths = [] + use_lora_model.forEach((m) => { + if (m.includes("models\\lora\\")) { + m = m.split("models\\lora\\")[1] + } else if (m.includes("models\\\\lora\\\\")) { + m = m.split("models\\\\lora\\\\")[1] + } else if (m.includes("models/lora/")) { + m = m.split("models/lora/")[1] } - field.value = model_name + m = m.replaceAll("\\\\", "/") + m = getModelPath(m, [".ckpt", ".safetensors"]) + modelPaths.push(m) }) - - // clear the remaining entries - let container = document.querySelector("#lora_model_container .model_entries") - for (let i = use_lora_model.length; i < loraModels.length; i++) { - let modelEntry = loraModels[i][2] - container.removeChild(modelEntry) - } - - loraModels.splice(use_lora_model.length) + loraModelField.modelNames = modelPaths }, readUI: () => { - let values = loraModels.map((e) => e[0].value) - values = values.filter((e) => e.trim() !== "") - values = values.length > 0 ? values : "None" - return values + return loraModelField.modelNames }, parse: (val) => { val = !val || val === "None" ? "" : val + if (typeof val === "string" && val.includes(",")) { + val = val.split(",") + val = val.map((v) => v.trim()) + val = val.map((v) => v.replaceAll("\\", "\\\\")) + val = val.map((v) => v.replaceAll('"', "")) + val = val.map((v) => v.replaceAll("'", "")) + val = val.map((v) => '"' + v + '"') + val = "[" + val + "]" + val = JSON.parse(val) + } val = Array.isArray(val) ? val : [val] return val }, @@ -332,31 +329,16 @@ const TASK_MAPPING = { lora_alpha: { name: "LoRA Strength", setUI: (lora_alpha) => { - for (let i = loraModels.length; i < lora_alpha.length; i++) { - createLoraEntry() - } - - lora_alpha.forEach((model_strength, i) => { - let field = loraModels[i][1] - field.value = model_strength - }) - - // clear the remaining entries - let container = document.querySelector("#lora_model_container .model_entries") - for (let i = lora_alpha.length; i < loraModels.length; i++) { - let modelEntry = loraModels[i][2] - container.removeChild(modelEntry) - } - - loraModels.splice(lora_alpha.length) + loraModelField.modelWeights = lora_alpha }, readUI: () => { - let models = loraModels.filter((e) => e[0].value.trim() !== "") - let values = models.map((e) => e[1].value) - values = values.length > 0 ? values : 0 - return values + return loraModelField.modelWeights }, parse: (val) => { + if (typeof val === "string" && val.includes(",")) { + val = "[" + val.replaceAll("'", '"') + "]" + val = JSON.parse(val) + } val = Array.isArray(val) ? val : [val] val = val.map((e) => parseFloat(e)) return val @@ -569,6 +551,8 @@ const TASK_TEXT_MAPPING = { use_stable_diffusion_model: "Stable Diffusion model", use_hypernetwork_model: "Hypernetwork model", hypernetwork_strength: "Hypernetwork Strength", + use_lora_model: "LoRA model", + lora_alpha: "LoRA Strength", } function parseTaskFromText(str) { const taskReqBody = {} diff --git a/ui/media/js/multi-model-selector.js b/ui/media/js/multi-model-selector.js index 24d5e0f1..5b719b88 100644 --- a/ui/media/js/multi-model-selector.js +++ b/ui/media/js/multi-model-selector.js @@ -215,4 +215,49 @@ class MultiModelSelector { get length() { return this.modelContainer.childElementCount } + + get modelNames() { + return this.modelElements.map((e) => e.name.value) + } + + set modelNames(newModelNames) { + this.resizeEntryList(newModelNames.length) + + // assign to the corresponding elements + let currElements = this.modelElements + for (let i = 0; i < newModelNames.length; i++) { + let curr = currElements[i] + + curr.name.value = newModelNames[i] + } + } + + get modelWeights() { + return this.modelElements.map((e) => e.weight.value) + } + + set modelWeights(newModelWeights) { + this.resizeEntryList(newModelWeights.length) + + // assign to the corresponding elements + let currElements = this.modelElements + for (let i = 0; i < newModelWeights.length; i++) { + let curr = currElements[i] + + curr.weight.value = newModelWeights[i] + } + } + + resizeEntryList(newLength) { + let currLength = this.length + if (currLength < newLength) { + for (let i = currLength; i < newLength; i++) { + this.addModelEntry() + } + } else { + for (let i = newLength; i < currLength; i++) { + this.removeModelEntry() + } + } + } }