diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 9cbbe8af..05f1b5dd 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -161,18 +161,7 @@ const TASK_MAPPING = { setUI: (use_stable_diffusion_model) => { const oldVal = stableDiffusionModelField.value - let pathIdx = use_stable_diffusion_model.lastIndexOf('/') // Linux, Mac paths - if (pathIdx < 0) { - pathIdx = use_stable_diffusion_model.lastIndexOf('\\') // Windows paths. - } - if (pathIdx >= 0) { - use_stable_diffusion_model = use_stable_diffusion_model.slice(pathIdx + 1) - } - const modelExt = '.ckpt' - if (use_stable_diffusion_model.endsWith(modelExt)) { - use_stable_diffusion_model = use_stable_diffusion_model.slice(0, use_stable_diffusion_model.length - modelExt.length) - } - + use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt']) stableDiffusionModelField.value = use_stable_diffusion_model if (!stableDiffusionModelField.value) { @@ -182,6 +171,19 @@ const TASK_MAPPING = { readUI: () => stableDiffusionModelField.value, parse: (val) => val }, + use_vae_model: { name: 'VAE model', + setUI: (use_vae_model) => { + const oldVal = vaeModelField.value + + if (use_vae_model !== '') { + use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt']) + use_vae_model = use_vae_model !== '' ? use_vae_model : oldVal + } + vaeModelField.value = use_vae_model + }, + readUI: () => vaeModelField.value, + parse: (val) => val + }, numOutputsParallel: { name: 'Parallel Images', setUI: (numOutputsParallel) => { @@ -310,6 +312,21 @@ function readUI() { 'reqBody': reqBody } } +function getModelPath(filename, extensions) +{ + let pathIdx = filename.lastIndexOf('/') // Linux, Mac paths + if (pathIdx < 0) { + pathIdx = filename.lastIndexOf('\\') // Windows paths. + } + if (pathIdx >= 0) { + filename = filename.slice(pathIdx + 1) + } + extensions.forEach(ext => { + if (filename.endsWith(ext)) { + filename = filename.slice(0, filename.length - ext.length) + } + }) +} const TASK_TEXT_MAPPING = { width: 'Width',