From 2111a81d18012f73a39ce5700d4c8ba5c58cf76e Mon Sep 17 00:00:00 2001 From: patriceac <48073125+patriceac@users.noreply.github.com> Date: Sat, 19 Nov 2022 00:56:44 -0800 Subject: [PATCH] Proper PR for VAE support in Use Settings --- ui/media/js/dnd.js | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 9cbbe8af..05f1b5dd 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -161,18 +161,7 @@ const TASK_MAPPING = { setUI: (use_stable_diffusion_model) => { const oldVal = stableDiffusionModelField.value - let pathIdx = use_stable_diffusion_model.lastIndexOf('/') // Linux, Mac paths - if (pathIdx < 0) { - pathIdx = use_stable_diffusion_model.lastIndexOf('\\') // Windows paths. - } - if (pathIdx >= 0) { - use_stable_diffusion_model = use_stable_diffusion_model.slice(pathIdx + 1) - } - const modelExt = '.ckpt' - if (use_stable_diffusion_model.endsWith(modelExt)) { - use_stable_diffusion_model = use_stable_diffusion_model.slice(0, use_stable_diffusion_model.length - modelExt.length) - } - + use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt']) stableDiffusionModelField.value = use_stable_diffusion_model if (!stableDiffusionModelField.value) { @@ -182,6 +171,19 @@ const TASK_MAPPING = { readUI: () => stableDiffusionModelField.value, parse: (val) => val }, + use_vae_model: { name: 'VAE model', + setUI: (use_vae_model) => { + const oldVal = vaeModelField.value + + if (use_vae_model !== '') { + use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt']) + use_vae_model = use_vae_model !== '' ? use_vae_model : oldVal + } + vaeModelField.value = use_vae_model + }, + readUI: () => vaeModelField.value, + parse: (val) => val + }, numOutputsParallel: { name: 'Parallel Images', setUI: (numOutputsParallel) => { @@ -310,6 +312,21 @@ function readUI() { 'reqBody': reqBody } } +function getModelPath(filename, extensions) +{ + let pathIdx = filename.lastIndexOf('/') // Linux, Mac paths + if (pathIdx < 0) { + pathIdx = filename.lastIndexOf('\\') // Windows paths. + } + if (pathIdx >= 0) { + filename = filename.slice(pathIdx + 1) + } + extensions.forEach(ext => { + if (filename.endsWith(ext)) { + filename = filename.slice(0, filename.length - ext.length) + } + }) +} const TASK_TEXT_MAPPING = { width: 'Width',