Proper PR for VAE support in Use Settings

This commit is contained in:
patriceac 2022-11-19 00:56:44 -08:00 committed by GitHub
parent 6799b3d7da
commit 2111a81d18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -161,18 +161,7 @@ const TASK_MAPPING = {
setUI: (use_stable_diffusion_model) => { setUI: (use_stable_diffusion_model) => {
const oldVal = stableDiffusionModelField.value const oldVal = stableDiffusionModelField.value
let pathIdx = use_stable_diffusion_model.lastIndexOf('/') // Linux, Mac paths use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt'])
if (pathIdx < 0) {
pathIdx = use_stable_diffusion_model.lastIndexOf('\\') // Windows paths.
}
if (pathIdx >= 0) {
use_stable_diffusion_model = use_stable_diffusion_model.slice(pathIdx + 1)
}
const modelExt = '.ckpt'
if (use_stable_diffusion_model.endsWith(modelExt)) {
use_stable_diffusion_model = use_stable_diffusion_model.slice(0, use_stable_diffusion_model.length - modelExt.length)
}
stableDiffusionModelField.value = use_stable_diffusion_model stableDiffusionModelField.value = use_stable_diffusion_model
if (!stableDiffusionModelField.value) { if (!stableDiffusionModelField.value) {
@ -182,6 +171,19 @@ const TASK_MAPPING = {
readUI: () => stableDiffusionModelField.value, readUI: () => stableDiffusionModelField.value,
parse: (val) => val parse: (val) => val
}, },
use_vae_model: { name: 'VAE model',
setUI: (use_vae_model) => {
const oldVal = vaeModelField.value
if (use_vae_model !== '') {
use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt'])
use_vae_model = use_vae_model !== '' ? use_vae_model : oldVal
}
vaeModelField.value = use_vae_model
},
readUI: () => vaeModelField.value,
parse: (val) => val
},
numOutputsParallel: { name: 'Parallel Images', numOutputsParallel: { name: 'Parallel Images',
setUI: (numOutputsParallel) => { setUI: (numOutputsParallel) => {
@ -310,6 +312,21 @@ function readUI() {
'reqBody': reqBody 'reqBody': reqBody
} }
} }
function getModelPath(filename, extensions)
{
let pathIdx = filename.lastIndexOf('/') // Linux, Mac paths
if (pathIdx < 0) {
pathIdx = filename.lastIndexOf('\\') // Windows paths.
}
if (pathIdx >= 0) {
filename = filename.slice(pathIdx + 1)
}
extensions.forEach(ext => {
if (filename.endsWith(ext)) {
filename = filename.slice(0, filename.length - ext.length)
}
})
}
const TASK_TEXT_MAPPING = { const TASK_TEXT_MAPPING = {
width: 'Width', width: 'Width',