mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-04-24 03:18:29 +02:00
Support custom VAE files; Use vae-ft-mse-840000-ema-pruned as the default VAE, which can be overridden by putting a .vae.pt file inside models/stable-diffusion with the same name as the ckpt model file. The UI / System Settings allows setting the default VAE model to use
This commit is contained in:
parent
79a7cd2938
commit
a8c16e39b8
@ -321,6 +321,36 @@ echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@if exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
|
||||||
|
for %%I in ("..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt") do if "%%~zI" EQU "334695179" (
|
||||||
|
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
|
||||||
|
) else (
|
||||||
|
echo. & echo "The default VAE (sd-vae-ft-mse-original) file present at models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt is invalid. It is only %%~zI bytes in size. Re-downloading.." & echo.
|
||||||
|
del "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@if not exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
|
||||||
|
@echo. & echo "Downloading data files (weights) for the default VAE (sd-vae-ft-mse-original).." & echo.
|
||||||
|
|
||||||
|
@call curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt
|
||||||
|
|
||||||
|
@if exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
|
||||||
|
for %%I in ("..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt") do if "%%~zI" NEQ "334695179" (
|
||||||
|
echo. & echo "Error: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: %%~zI" & echo.
|
||||||
|
echo. & echo "Error downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
|
||||||
|
pause
|
||||||
|
exit /b
|
||||||
|
)
|
||||||
|
) else (
|
||||||
|
@echo. & echo "Error downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
|
||||||
|
pause
|
||||||
|
exit /b
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
||||||
@if "%ERRORLEVEL%" NEQ "0" (
|
@if "%ERRORLEVEL%" NEQ "0" (
|
||||||
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
|
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
|
||||||
|
@ -300,6 +300,38 @@ if [ ! -f "RealESRGAN_x4plus_anime_6B.pth" ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
|
||||||
|
model_size=`find ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt -printf "%s"`
|
||||||
|
|
||||||
|
if [ "$model_size" -eq "334695179" ]; then
|
||||||
|
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
|
||||||
|
else
|
||||||
|
printf "\n\nThe model file present at models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt is invalid. It is only $model_size bytes in size. Re-downloading.."
|
||||||
|
rm ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
|
||||||
|
echo "Downloading data files (weights) for the default VAE (sd-vae-ft-mse-original).."
|
||||||
|
|
||||||
|
curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt
|
||||||
|
|
||||||
|
if [ -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
|
||||||
|
model_size=`find ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt -printf "%s"`
|
||||||
|
if [ ! "$model_size" -eq "334695179" ]; then
|
||||||
|
printf "\n\nError: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: $model_size\n\n"
|
||||||
|
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
|
||||||
|
read -p "Press any key to continue"
|
||||||
|
exit
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
|
||||||
|
read -p "Press any key to continue"
|
||||||
|
exit
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||||
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
||||||
echo sd_install_complete >> ../scripts/install_status.txt
|
echo sd_install_complete >> ../scripts/install_status.txt
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
<div id="container">
|
<div id="container">
|
||||||
<div id="top-nav">
|
<div id="top-nav">
|
||||||
<div id="logo">
|
<div id="logo">
|
||||||
<h1>Stable Diffusion UI <small>v2.3.7 <span id="updateBranchLabel"></span></small></h1>
|
<h1>Stable Diffusion UI <small>v2.3.8 <span id="updateBranchLabel"></span></small></h1>
|
||||||
</div>
|
</div>
|
||||||
<ul id="top-nav-items">
|
<ul id="top-nav-items">
|
||||||
<li class="dropdown">
|
<li class="dropdown">
|
||||||
@ -38,6 +38,12 @@
|
|||||||
<br/>
|
<br/>
|
||||||
<li><label for="theme">Theme: </label><select id="theme" name="theme"><option value="theme-default">Default</option></select></li>
|
<li><label for="theme">Theme: </label><select id="theme" name="theme"><option value="theme-default">Default</option></select></li>
|
||||||
<li><input id="save_to_disk" name="save_to_disk" type="checkbox"> <label for="save_to_disk">Automatically save to <input id="diskPath" name="diskPath" size="40" disabled></label></li>
|
<li><input id="save_to_disk" name="save_to_disk" type="checkbox"> <label for="save_to_disk">Automatically save to <input id="diskPath" name="diskPath" size="40" disabled></label></li>
|
||||||
|
<li>
|
||||||
|
<label for="default_vae_model">Default VAE:</label></td><td>
|
||||||
|
<select id="default_vae_model" name="default_vae_model">
|
||||||
|
<!-- <option value="vae-ft-mse-840000-ema-pruned" selected>vae-ft-mse-840000-ema-pruned</option> -->
|
||||||
|
</select>
|
||||||
|
</li>
|
||||||
<li><input id="sound_toggle" name="sound_toggle" type="checkbox" checked> <label for="sound_toggle">Play sound on task completion</label></li>
|
<li><input id="sound_toggle" name="sound_toggle" type="checkbox" checked> <label for="sound_toggle">Play sound on task completion</label></li>
|
||||||
<li><input id="turbo" name="turbo" type="checkbox" checked> <label for="turbo">Turbo mode <small>(generates images faster, but uses an additional 1 GB of GPU memory)</small></label></li>
|
<li><input id="turbo" name="turbo" type="checkbox" checked> <label for="turbo">Turbo mode <small>(generates images faster, but uses an additional 1 GB of GPU memory)</small></label></li>
|
||||||
<li><input id="use_cpu" name="use_cpu" type="checkbox"> <label for="use_cpu">Use CPU instead of GPU <small>(warning: this will be *very* slow)</small></label></li>
|
<li><input id="use_cpu" name="use_cpu" type="checkbox"> <label for="use_cpu">Use CPU instead of GPU <small>(warning: this will be *very* slow)</small></label></li>
|
||||||
@ -273,7 +279,7 @@
|
|||||||
<script src="media/js/inpainting-editor.js?v=1"></script>
|
<script src="media/js/inpainting-editor.js?v=1"></script>
|
||||||
<script src="media/js/image-modifiers.js?v=3"></script>
|
<script src="media/js/image-modifiers.js?v=3"></script>
|
||||||
<script src="media/js/auto-save.js?v=2"></script>
|
<script src="media/js/auto-save.js?v=2"></script>
|
||||||
<script src="media/js/main.js?v=5"></script>
|
<script src="media/js/main.js?v=6"></script>
|
||||||
<script src="media/js/themes.js?v=2"></script>
|
<script src="media/js/themes.js?v=2"></script>
|
||||||
<script>
|
<script>
|
||||||
async function init() {
|
async function init() {
|
||||||
|
@ -39,6 +39,7 @@ let useFaceCorrectionField = document.querySelector("#use_face_correction")
|
|||||||
let useUpscalingField = document.querySelector("#use_upscale")
|
let useUpscalingField = document.querySelector("#use_upscale")
|
||||||
let upscaleModelField = document.querySelector("#upscale_model")
|
let upscaleModelField = document.querySelector("#upscale_model")
|
||||||
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
|
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
|
||||||
|
let vaeModelField = document.querySelector('#default_vae_model')
|
||||||
let outputFormatField = document.querySelector('#output_format')
|
let outputFormatField = document.querySelector('#output_format')
|
||||||
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
|
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
|
||||||
let updateBranchLabel = document.querySelector("#updateBranchLabel")
|
let updateBranchLabel = document.querySelector("#updateBranchLabel")
|
||||||
@ -1104,15 +1105,13 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
|
|||||||
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
|
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
|
||||||
updatePromptStrength()
|
updatePromptStrength()
|
||||||
|
|
||||||
useBetaChannelField.addEventListener('click', async function(e) {
|
async function changeAppConfig(configDelta) {
|
||||||
if (!isServerAvailable()) {
|
// if (!isServerAvailable()) {
|
||||||
// logError('The server is still starting up..')
|
// // logError('The server is still starting up..')
|
||||||
alert('The server is still starting up..')
|
// alert('The server is still starting up..')
|
||||||
e.preventDefault()
|
// e.preventDefault()
|
||||||
return false
|
// return false
|
||||||
}
|
// }
|
||||||
|
|
||||||
let updateBranch = (this.checked ? 'beta' : 'main')
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let res = await fetch('/app_config', {
|
let res = await fetch('/app_config', {
|
||||||
@ -1120,9 +1119,7 @@ useBetaChannelField.addEventListener('click', async function(e) {
|
|||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify(configDelta)
|
||||||
'update_branch': updateBranch
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
res = await res.json()
|
res = await res.json()
|
||||||
|
|
||||||
@ -1130,6 +1127,20 @@ useBetaChannelField.addEventListener('click', async function(e) {
|
|||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log('set config status error', e)
|
console.log('set config status error', e)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useBetaChannelField.addEventListener('click', async function(e) {
|
||||||
|
let updateBranch = (this.checked ? 'beta' : 'main')
|
||||||
|
|
||||||
|
await changeAppConfig({
|
||||||
|
'update_branch': updateBranch
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
vaeModelField.addEventListener('change', async function() {
|
||||||
|
await changeAppConfig({
|
||||||
|
'model_vae': this.value
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
async function getAppConfig() {
|
async function getAppConfig() {
|
||||||
@ -1155,11 +1166,14 @@ async function getModels() {
|
|||||||
let res = await fetch('/get/models')
|
let res = await fetch('/get/models')
|
||||||
const models = await res.json()
|
const models = await res.json()
|
||||||
|
|
||||||
// let activeModel = models['active']
|
let activeModels = models['active']
|
||||||
let modelOptions = models['options']
|
let modelOptions = models['options']
|
||||||
let stableDiffusionOptions = modelOptions['stable-diffusion']
|
let stableDiffusionOptions = modelOptions['stable-diffusion']
|
||||||
|
let vaeOptions = modelOptions['vae']
|
||||||
|
let activeVae = activeModels['vae']
|
||||||
|
|
||||||
stableDiffusionOptions.forEach(modelName => {
|
function createModelOptions(modelField, selectedModel) {
|
||||||
|
return function(modelName) {
|
||||||
let modelOption = document.createElement('option')
|
let modelOption = document.createElement('option')
|
||||||
modelOption.value = modelName
|
modelOption.value = modelName
|
||||||
modelOption.innerText = modelName
|
modelOption.innerText = modelName
|
||||||
@ -1168,8 +1182,12 @@ async function getModels() {
|
|||||||
modelOption.selected = true
|
modelOption.selected = true
|
||||||
}
|
}
|
||||||
|
|
||||||
stableDiffusionModelField.appendChild(modelOption)
|
modelField.appendChild(modelOption)
|
||||||
})
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedModel))
|
||||||
|
vaeOptions.forEach(createModelOptions(vaeModelField, activeVae))
|
||||||
|
|
||||||
// TODO: set default for model here too
|
// TODO: set default for model here too
|
||||||
SETTINGS[model_setting_key].default = stableDiffusionOptions[0]
|
SETTINGS[model_setting_key].default = stableDiffusionOptions[0]
|
||||||
|
@ -23,6 +23,7 @@ class Request:
|
|||||||
use_face_correction: str = None # or "GFPGANv1.3"
|
use_face_correction: str = None # or "GFPGANv1.3"
|
||||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||||
use_stable_diffusion_model: str = "sd-v1-4"
|
use_stable_diffusion_model: str = "sd-v1-4"
|
||||||
|
use_vae_model: str = None
|
||||||
show_only_filtered_image: bool = False
|
show_only_filtered_image: bool = False
|
||||||
output_format: str = "jpeg" # or "png"
|
output_format: str = "jpeg" # or "png"
|
||||||
|
|
||||||
@ -45,6 +46,7 @@ class Request:
|
|||||||
"use_face_correction": self.use_face_correction,
|
"use_face_correction": self.use_face_correction,
|
||||||
"use_upscale": self.use_upscale,
|
"use_upscale": self.use_upscale,
|
||||||
"use_stable_diffusion_model": self.use_stable_diffusion_model,
|
"use_stable_diffusion_model": self.use_stable_diffusion_model,
|
||||||
|
"use_vae_model": self.use_vae_model,
|
||||||
"output_format": self.output_format,
|
"output_format": self.output_format,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,6 +69,7 @@ class Request:
|
|||||||
use_face_correction: {self.use_face_correction}
|
use_face_correction: {self.use_face_correction}
|
||||||
use_upscale: {self.use_upscale}
|
use_upscale: {self.use_upscale}
|
||||||
use_stable_diffusion_model: {self.use_stable_diffusion_model}
|
use_stable_diffusion_model: {self.use_stable_diffusion_model}
|
||||||
|
use_vae_model: {self.use_vae_model}
|
||||||
show_only_filtered_image: {self.show_only_filtered_image}
|
show_only_filtered_image: {self.show_only_filtered_image}
|
||||||
output_format: {self.output_format}
|
output_format: {self.output_format}
|
||||||
|
|
||||||
|
@ -87,6 +87,7 @@ def device_init(device_selection=None):
|
|||||||
thread_data.temp_images = {}
|
thread_data.temp_images = {}
|
||||||
|
|
||||||
thread_data.ckpt_file = None
|
thread_data.ckpt_file = None
|
||||||
|
thread_data.vae_file = None
|
||||||
thread_data.gfpgan_file = None
|
thread_data.gfpgan_file = None
|
||||||
thread_data.real_esrgan_file = None
|
thread_data.real_esrgan_file = None
|
||||||
|
|
||||||
@ -184,7 +185,7 @@ def load_model_ckpt():
|
|||||||
if thread_data.device == 'cpu':
|
if thread_data.device == 'cpu':
|
||||||
thread_data.precision = 'full'
|
thread_data.precision = 'full'
|
||||||
|
|
||||||
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
|
print('loading', thread_data.ckpt_file + '.ckpt', 'to', thread_data.device, 'using precision', thread_data.precision)
|
||||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||||
li, lo = [], []
|
li, lo = [], []
|
||||||
for key, value in sd.items():
|
for key, value in sd.items():
|
||||||
@ -231,6 +232,16 @@ def load_model_ckpt():
|
|||||||
|
|
||||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||||
|
|
||||||
|
if thread_data.vae_file is not None:
|
||||||
|
if os.path.exists(thread_data.vae_file + '.vae.pt'):
|
||||||
|
print(f"Loading VAE weights from: {thread_data.vae_file}.vae.pt")
|
||||||
|
vae_ckpt = torch.load(thread_data.vae_file + '.vae.pt', map_location="cpu")
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||||
|
modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
|
else:
|
||||||
|
print(f'Cannot find VAE file: {thread_data.vae_file}.vae.pt')
|
||||||
|
|
||||||
modelFS.eval()
|
modelFS.eval()
|
||||||
if thread_data.device != 'cpu':
|
if thread_data.device != 'cpu':
|
||||||
if thread_data.reduced_memory:
|
if thread_data.reduced_memory:
|
||||||
@ -459,8 +470,9 @@ def do_mk_img(req: Request):
|
|||||||
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
||||||
|
|
||||||
needs_model_reload = False
|
needs_model_reload = False
|
||||||
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
|
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
||||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||||
|
thread_data.vae_file = req.use_vae_model
|
||||||
needs_model_reload = True
|
needs_model_reload = True
|
||||||
|
|
||||||
if thread_data.has_valid_gpu:
|
if thread_data.has_valid_gpu:
|
||||||
|
@ -73,6 +73,7 @@ class ImageRequest(BaseModel):
|
|||||||
use_face_correction: str = None # or "GFPGANv1.3"
|
use_face_correction: str = None # or "GFPGANv1.3"
|
||||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||||
use_stable_diffusion_model: str = "sd-v1-4"
|
use_stable_diffusion_model: str = "sd-v1-4"
|
||||||
|
use_vae_model: str = None
|
||||||
show_only_filtered_image: bool = False
|
show_only_filtered_image: bool = False
|
||||||
output_format: str = "jpeg" # or "png"
|
output_format: str = "jpeg" # or "png"
|
||||||
|
|
||||||
@ -170,27 +171,34 @@ render_threads = []
|
|||||||
current_state = ServerStates.Init
|
current_state = ServerStates.Init
|
||||||
current_state_error:Exception = None
|
current_state_error:Exception = None
|
||||||
current_model_path = None
|
current_model_path = None
|
||||||
|
current_vae_path = None
|
||||||
tasks_queue = []
|
tasks_queue = []
|
||||||
task_cache = TaskCache()
|
task_cache = TaskCache()
|
||||||
default_model_to_load = None
|
default_model_to_load = None
|
||||||
|
default_vae_to_load = None
|
||||||
weak_thread_data = weakref.WeakKeyDictionary()
|
weak_thread_data = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
def preload_model(file_path=None):
|
def preload_model(ckpt_file_path=None, vae_file_path=None):
|
||||||
global current_state, current_state_error, current_model_path
|
global current_state, current_state_error, current_model_path
|
||||||
if file_path == None:
|
if ckpt_file_path == None:
|
||||||
file_path = default_model_to_load
|
ckpt_file_path = default_model_to_load
|
||||||
if file_path == current_model_path:
|
if vae_file_path == None:
|
||||||
|
vae_file_path = default_vae_to_load
|
||||||
|
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
|
||||||
return
|
return
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
try:
|
try:
|
||||||
from . import runtime
|
from . import runtime
|
||||||
runtime.thread_data.ckpt_file = file_path
|
runtime.thread_data.ckpt_file = ckpt_file_path
|
||||||
|
runtime.thread_data.vae_file = vae_file_path
|
||||||
runtime.load_model_ckpt()
|
runtime.load_model_ckpt()
|
||||||
current_model_path = file_path
|
current_model_path = ckpt_file_path
|
||||||
|
current_vae_path = vae_file_path
|
||||||
current_state_error = None
|
current_state_error = None
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
current_model_path = None
|
current_model_path = None
|
||||||
|
current_vae_path = None
|
||||||
current_state_error = e
|
current_state_error = e
|
||||||
current_state = ServerStates.Unavailable
|
current_state = ServerStates.Unavailable
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
@ -240,7 +248,7 @@ def thread_get_next_task():
|
|||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
|
|
||||||
def thread_render(device):
|
def thread_render(device):
|
||||||
global current_state, current_state_error, current_model_path
|
global current_state, current_state_error, current_model_path, current_vae_path
|
||||||
from . import runtime
|
from . import runtime
|
||||||
try:
|
try:
|
||||||
runtime.device_init(device)
|
runtime.device_init(device)
|
||||||
@ -289,6 +297,7 @@ def thread_render(device):
|
|||||||
if current_state == ServerStates.LoadingModel:
|
if current_state == ServerStates.LoadingModel:
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
current_model_path = task.request.use_stable_diffusion_model
|
current_model_path = task.request.use_stable_diffusion_model
|
||||||
|
current_vae_path = task.request.use_vae_model
|
||||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||||
runtime.thread_data.stop_processing = True
|
runtime.thread_data.stop_processing = True
|
||||||
if isinstance(current_state_error, StopAsyncIteration):
|
if isinstance(current_state_error, StopAsyncIteration):
|
||||||
@ -411,6 +420,7 @@ def render(req : ImageRequest):
|
|||||||
r.use_upscale: str = req.use_upscale
|
r.use_upscale: str = req.use_upscale
|
||||||
r.use_face_correction = req.use_face_correction
|
r.use_face_correction = req.use_face_correction
|
||||||
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
||||||
|
r.use_vae_model = req.use_vae_model
|
||||||
r.show_only_filtered_image = req.show_only_filtered_image
|
r.show_only_filtered_image = req.show_only_filtered_image
|
||||||
r.output_format = req.output_format
|
r.output_format = req.output_format
|
||||||
|
|
||||||
|
71
ui/server.py
71
ui/server.py
@ -30,6 +30,9 @@ APP_CONFIG_DEFAULT_MODELS = [
|
|||||||
'custom-model', # Check if user has a custom model, use it first.
|
'custom-model', # Check if user has a custom model, use it first.
|
||||||
'sd-v1-4', # Default fallback.
|
'sd-v1-4', # Default fallback.
|
||||||
]
|
]
|
||||||
|
APP_CONFIG_DEFAULT_VAE = [
|
||||||
|
'vae-ft-mse-840000-ema-pruned', # Default fallback.
|
||||||
|
]
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
@ -129,37 +132,57 @@ def setConfig(config):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
def resolve_model_to_use(model_name:str=None):
|
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extension:str, default_models=[]):
|
||||||
|
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
if 'model' in config and model_type in config['model']:
|
||||||
model_name = config['model']['stable-diffusion']
|
model_name = config['model'][model_type]
|
||||||
if model_name:
|
if model_name:
|
||||||
# Check models directory
|
# Check models directory
|
||||||
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
|
||||||
if os.path.exists(models_dir_path + '.ckpt'):
|
if os.path.exists(models_dir_path + model_extension):
|
||||||
return models_dir_path
|
return models_dir_path
|
||||||
if os.path.exists(model_name + '.ckpt'):
|
if os.path.exists(model_name + model_extension):
|
||||||
# Direct Path to file
|
# Direct Path to file
|
||||||
model_name = os.path.abspath(model_name)
|
model_name = os.path.abspath(model_name)
|
||||||
return model_name
|
return model_name
|
||||||
# Default locations
|
# Default locations
|
||||||
if model_name in APP_CONFIG_DEFAULT_MODELS:
|
if model_name in default_models:
|
||||||
default_model_path = os.path.join(SD_DIR, model_name)
|
default_model_path = os.path.join(SD_DIR, model_name)
|
||||||
if os.path.exists(default_model_path + '.ckpt'):
|
if os.path.exists(default_model_path + model_extension):
|
||||||
return default_model_path
|
return default_model_path
|
||||||
# Can't find requested model, check the default paths.
|
# Can't find requested model, check the default paths.
|
||||||
for default_model in APP_CONFIG_DEFAULT_MODELS:
|
for default_model in default_models:
|
||||||
default_model_path = os.path.join(SD_DIR, default_model)
|
for model_dir in model_dirs:
|
||||||
if os.path.exists(default_model_path + '.ckpt'):
|
default_model_path = os.path.join(model_dir, default_model)
|
||||||
|
if os.path.exists(default_model_path + model_extension):
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
|
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
||||||
return default_model_path
|
return default_model_path
|
||||||
raise Exception('No valid models found.')
|
raise Exception('No valid models found.')
|
||||||
|
|
||||||
|
def resolve_ckpt_to_use(model_name:str=None):
|
||||||
|
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extension='.ckpt', default_models=APP_CONFIG_DEFAULT_MODELS)
|
||||||
|
|
||||||
|
def resolve_vae_to_use(ckpt_model_path:str=None):
|
||||||
|
if ckpt_model_path is not None:
|
||||||
|
if os.path.exists(ckpt_model_path + '.vae.pt'):
|
||||||
|
return ckpt_model_path
|
||||||
|
|
||||||
|
ckpt_model_name = os.path.basename(ckpt_model_path)
|
||||||
|
model_dirs = [os.path.join(MODELS_DIR, 'stable-diffusion'), SD_DIR]
|
||||||
|
for model_dir in model_dirs:
|
||||||
|
default_model_path = os.path.join(model_dir, ckpt_model_name)
|
||||||
|
if os.path.exists(default_model_path + '.vae.pt'):
|
||||||
|
return default_model_path
|
||||||
|
|
||||||
|
return resolve_model_to_use(model_name=None, model_type='vae', model_dir='stable-diffusion', model_extension='.vae.pt', default_models=APP_CONFIG_DEFAULT_VAE)
|
||||||
|
|
||||||
class SetAppConfigRequest(BaseModel):
|
class SetAppConfigRequest(BaseModel):
|
||||||
update_branch: str = None
|
update_branch: str = None
|
||||||
render_devices: Union[List[str], List[int], str, int] = None
|
render_devices: Union[List[str], List[int], str, int] = None
|
||||||
|
model_vae: str = None
|
||||||
|
|
||||||
@app.post('/app_config')
|
@app.post('/app_config')
|
||||||
async def setAppConfig(req : SetAppConfigRequest):
|
async def setAppConfig(req : SetAppConfigRequest):
|
||||||
@ -180,6 +203,10 @@ async def setAppConfig(req : SetAppConfigRequest):
|
|||||||
render_devices.append('GPU:' + req.render_devices)
|
render_devices.append('GPU:' + req.render_devices)
|
||||||
if len(render_devices) > 0:
|
if len(render_devices) > 0:
|
||||||
config['render_devices'] = render_devices
|
config['render_devices'] = render_devices
|
||||||
|
if req.model_vae:
|
||||||
|
if 'model' not in config:
|
||||||
|
config['model'] = {}
|
||||||
|
config['model']['vae'] = req.model_vae
|
||||||
try:
|
try:
|
||||||
setConfig(config)
|
setConfig(config)
|
||||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||||
@ -191,28 +218,28 @@ def getModels():
|
|||||||
models = {
|
models = {
|
||||||
'active': {
|
'active': {
|
||||||
'stable-diffusion': 'sd-v1-4',
|
'stable-diffusion': 'sd-v1-4',
|
||||||
|
'vae': '',
|
||||||
},
|
},
|
||||||
'options': {
|
'options': {
|
||||||
'stable-diffusion': ['sd-v1-4'],
|
'stable-diffusion': ['sd-v1-4'],
|
||||||
|
'vae': [],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# custom models
|
# custom models
|
||||||
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
||||||
|
for model_type, model_extension in [('stable-diffusion', '.ckpt'), ('vae', '.vae.pt')]:
|
||||||
for file in os.listdir(sd_models_dir):
|
for file in os.listdir(sd_models_dir):
|
||||||
if file.endswith('.ckpt'):
|
if file.endswith(model_extension):
|
||||||
model_name = os.path.splitext(file)[0]
|
model_name = file[:-len(model_extension)]
|
||||||
models['options']['stable-diffusion'].append(model_name)
|
models['options'][model_type].append(model_name)
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||||
if os.path.exists(custom_weight_path):
|
if os.path.exists(custom_weight_path):
|
||||||
models['active']['stable-diffusion'] = 'custom-model'
|
|
||||||
models['options']['stable-diffusion'].append('custom-model')
|
models['options']['stable-diffusion'].append('custom-model')
|
||||||
|
|
||||||
config = getConfig()
|
models['active']['vae'] = os.path.basename(task_manager.default_vae_to_load)
|
||||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
|
||||||
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
|
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
@ -283,7 +310,8 @@ def render(req : task_manager.ImageRequest):
|
|||||||
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
|
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
|
||||||
try:
|
try:
|
||||||
save_model_to_config(req.use_stable_diffusion_model)
|
save_model_to_config(req.use_stable_diffusion_model)
|
||||||
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
|
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
|
||||||
|
req.use_vae_model = resolve_vae_to_use(ckpt_model_path=req.use_stable_diffusion_model)
|
||||||
new_task = task_manager.render(req)
|
new_task = task_manager.render(req)
|
||||||
response = {
|
response = {
|
||||||
'status': str(task_manager.current_state),
|
'status': str(task_manager.current_state),
|
||||||
@ -361,7 +389,8 @@ logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
|||||||
config = getConfig()
|
config = getConfig()
|
||||||
|
|
||||||
# Start the task_manager
|
# Start the task_manager
|
||||||
task_manager.default_model_to_load = resolve_model_to_use()
|
task_manager.default_model_to_load = resolve_ckpt_to_use()
|
||||||
|
task_manager.default_vae_to_load = resolve_vae_to_use(ckpt_model_path=task_manager.default_model_to_load)
|
||||||
if 'render_devices' in config: # Start a new thread for each device.
|
if 'render_devices' in config: # Start a new thread for each device.
|
||||||
if isinstance(config['render_devices'], str):
|
if isinstance(config['render_devices'], str):
|
||||||
config['render_devices'] = config['render_devices'].split(',')
|
config['render_devices'] = config['render_devices'].split(',')
|
||||||
|
Loading…
Reference in New Issue
Block a user