Support lora models in subfolders when scanning the <lora> tag (#1521)

* Recursive lora search

* Support lora models in subfolders when scanning the <lora> tag
This commit is contained in:
JeLuF 2023-08-29 07:18:57 +02:00 committed by GitHub
parent e49772030d
commit b89d152540
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 13 deletions

View File

@ -678,6 +678,29 @@ function getAllModelNames(type) {
return f(modelsOptions[type]) return f(modelsOptions[type])
} }
// gets a flattened list of all models of a certain type. e.g. "path/subpath/modelname"
// use the filter to search for all models having a certain name.
function getAllModelPathes(type,filter="") {
function f(tree, prefix) {
if (tree == undefined) {
return []
}
let result = []
tree.forEach((e) => {
if (typeof e == "object") {
result = result.concat(f(e[1], prefix + e[0] + "/"))
} else {
if (filter=="" || e==filter) {
result.push(prefix + e)
}
}
})
return result
}
return f(modelsOptions[type], "")
}
function onUseAsThumbnailClick(req, img) { function onUseAsThumbnailClick(req, img) {
let scale = 1 let scale = 1
let targetWidth = img.naturalWidth let targetWidth = img.naturalWidth

View File

@ -29,22 +29,13 @@
let modelWeights = LoRA.map(e => e.lora_alpha_0) let modelWeights = LoRA.map(e => e.lora_alpha_0)
loraModelField.value = {modelNames: modelNames, modelWeights: modelWeights} loraModelField.value = {modelNames: modelNames, modelWeights: modelWeights}
showToast("Prompt successfully processed", LoRA[0].lora_model_0); showToast("Prompt successfully processed")
} }
//promptField.dispatchEvent(new Event('change')); //promptField.dispatchEvent(new Event('change'));
}); });
function isModelAvailable(array, searchString) {
const foundItem = array.find(function(item) {
item = item.toString().toLowerCase();
return item === searchString.toLowerCase()
});
return foundItem || "";
}
// extract LoRA tags from strings // extract LoRA tags from strings
function extractLoraTags(prompt) { function extractLoraTags(prompt) {
// Define the regular expression for the tags // Define the regular expression for the tags
@ -55,11 +46,13 @@
// Iterate over the string, finding matches // Iterate over the string, finding matches
for (const match of prompt.matchAll(regex)) { for (const match of prompt.matchAll(regex)) {
const modelFileName = isModelAvailable(modelsCache.options.lora, match[1].trim()) const modelFileName = match[1].trim()
if (modelFileName !== "") { const loraPathes = getAllModelPathes("lora", modelFileName)
if (loraPathes.length > 0) {
const loraPath = loraPathes[0]
// Initialize an object to hold a match // Initialize an object to hold a match
let loraTag = { let loraTag = {
lora_model_0: modelFileName, lora_model_0: loraPath,
} }
//console.log("Model:" + modelFileName); //console.log("Model:" + modelFileName);