mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-26 08:08:58 +01:00
Support custom VAE files; Use vae-ft-mse-840000-ema-pruned as the default VAE, which can be overridden by putting a .vae.pt file inside models/stable-diffusion with the same name as the ckpt model file. The UI / System Settings allows setting the default VAE model to use
This commit is contained in:
parent
79a7cd2938
commit
a8c16e39b8
@ -321,6 +321,36 @@ echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
|
||||
|
||||
|
||||
|
||||
@if exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
|
||||
for %%I in ("..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt") do if "%%~zI" EQU "334695179" (
|
||||
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
|
||||
) else (
|
||||
echo. & echo "The default VAE (sd-vae-ft-mse-original) file present at models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt is invalid. It is only %%~zI bytes in size. Re-downloading.." & echo.
|
||||
del "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt"
|
||||
)
|
||||
)
|
||||
|
||||
@if not exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
|
||||
@echo. & echo "Downloading data files (weights) for the default VAE (sd-vae-ft-mse-original).." & echo.
|
||||
|
||||
@call curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt
|
||||
|
||||
@if exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
|
||||
for %%I in ("..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt") do if "%%~zI" NEQ "334695179" (
|
||||
echo. & echo "Error: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: %%~zI" & echo.
|
||||
echo. & echo "Error downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
) else (
|
||||
@echo. & echo "Error downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
||||
@if "%ERRORLEVEL%" NEQ "0" (
|
||||
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
|
||||
|
@ -300,6 +300,38 @@ if [ ! -f "RealESRGAN_x4plus_anime_6B.pth" ]; then
|
||||
fi
|
||||
|
||||
|
||||
if [ -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
|
||||
model_size=`find ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt -printf "%s"`
|
||||
|
||||
if [ "$model_size" -eq "334695179" ]; then
|
||||
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
|
||||
else
|
||||
printf "\n\nThe model file present at models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt is invalid. It is only $model_size bytes in size. Re-downloading.."
|
||||
rm ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
|
||||
echo "Downloading data files (weights) for the default VAE (sd-vae-ft-mse-original).."
|
||||
|
||||
curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt
|
||||
|
||||
if [ -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
|
||||
model_size=`find ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt -printf "%s"`
|
||||
if [ ! "$model_size" -eq "334695179" ]; then
|
||||
printf "\n\nError: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: $model_size\n\n"
|
||||
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
|
||||
read -p "Press any key to continue"
|
||||
exit
|
||||
fi
|
||||
else
|
||||
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
|
||||
read -p "Press any key to continue"
|
||||
exit
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
||||
echo sd_install_complete >> ../scripts/install_status.txt
|
||||
|
@ -18,7 +18,7 @@
|
||||
<div id="container">
|
||||
<div id="top-nav">
|
||||
<div id="logo">
|
||||
<h1>Stable Diffusion UI <small>v2.3.7 <span id="updateBranchLabel"></span></small></h1>
|
||||
<h1>Stable Diffusion UI <small>v2.3.8 <span id="updateBranchLabel"></span></small></h1>
|
||||
</div>
|
||||
<ul id="top-nav-items">
|
||||
<li class="dropdown">
|
||||
@ -38,6 +38,12 @@
|
||||
<br/>
|
||||
<li><label for="theme">Theme: </label><select id="theme" name="theme"><option value="theme-default">Default</option></select></li>
|
||||
<li><input id="save_to_disk" name="save_to_disk" type="checkbox"> <label for="save_to_disk">Automatically save to <input id="diskPath" name="diskPath" size="40" disabled></label></li>
|
||||
<li>
|
||||
<label for="default_vae_model">Default VAE:</label></td><td>
|
||||
<select id="default_vae_model" name="default_vae_model">
|
||||
<!-- <option value="vae-ft-mse-840000-ema-pruned" selected>vae-ft-mse-840000-ema-pruned</option> -->
|
||||
</select>
|
||||
</li>
|
||||
<li><input id="sound_toggle" name="sound_toggle" type="checkbox" checked> <label for="sound_toggle">Play sound on task completion</label></li>
|
||||
<li><input id="turbo" name="turbo" type="checkbox" checked> <label for="turbo">Turbo mode <small>(generates images faster, but uses an additional 1 GB of GPU memory)</small></label></li>
|
||||
<li><input id="use_cpu" name="use_cpu" type="checkbox"> <label for="use_cpu">Use CPU instead of GPU <small>(warning: this will be *very* slow)</small></label></li>
|
||||
@ -273,7 +279,7 @@
|
||||
<script src="media/js/inpainting-editor.js?v=1"></script>
|
||||
<script src="media/js/image-modifiers.js?v=3"></script>
|
||||
<script src="media/js/auto-save.js?v=2"></script>
|
||||
<script src="media/js/main.js?v=5"></script>
|
||||
<script src="media/js/main.js?v=6"></script>
|
||||
<script src="media/js/themes.js?v=2"></script>
|
||||
<script>
|
||||
async function init() {
|
||||
|
@ -39,6 +39,7 @@ let useFaceCorrectionField = document.querySelector("#use_face_correction")
|
||||
let useUpscalingField = document.querySelector("#use_upscale")
|
||||
let upscaleModelField = document.querySelector("#upscale_model")
|
||||
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
|
||||
let vaeModelField = document.querySelector('#default_vae_model')
|
||||
let outputFormatField = document.querySelector('#output_format')
|
||||
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
|
||||
let updateBranchLabel = document.querySelector("#updateBranchLabel")
|
||||
@ -1104,15 +1105,13 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
|
||||
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
|
||||
updatePromptStrength()
|
||||
|
||||
useBetaChannelField.addEventListener('click', async function(e) {
|
||||
if (!isServerAvailable()) {
|
||||
// logError('The server is still starting up..')
|
||||
alert('The server is still starting up..')
|
||||
e.preventDefault()
|
||||
return false
|
||||
}
|
||||
|
||||
let updateBranch = (this.checked ? 'beta' : 'main')
|
||||
async function changeAppConfig(configDelta) {
|
||||
// if (!isServerAvailable()) {
|
||||
// // logError('The server is still starting up..')
|
||||
// alert('The server is still starting up..')
|
||||
// e.preventDefault()
|
||||
// return false
|
||||
// }
|
||||
|
||||
try {
|
||||
let res = await fetch('/app_config', {
|
||||
@ -1120,9 +1119,7 @@ useBetaChannelField.addEventListener('click', async function(e) {
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
'update_branch': updateBranch
|
||||
})
|
||||
body: JSON.stringify(configDelta)
|
||||
})
|
||||
res = await res.json()
|
||||
|
||||
@ -1130,6 +1127,20 @@ useBetaChannelField.addEventListener('click', async function(e) {
|
||||
} catch (e) {
|
||||
console.log('set config status error', e)
|
||||
}
|
||||
}
|
||||
|
||||
useBetaChannelField.addEventListener('click', async function(e) {
|
||||
let updateBranch = (this.checked ? 'beta' : 'main')
|
||||
|
||||
await changeAppConfig({
|
||||
'update_branch': updateBranch
|
||||
})
|
||||
})
|
||||
|
||||
vaeModelField.addEventListener('change', async function() {
|
||||
await changeAppConfig({
|
||||
'model_vae': this.value
|
||||
})
|
||||
})
|
||||
|
||||
async function getAppConfig() {
|
||||
@ -1155,21 +1166,28 @@ async function getModels() {
|
||||
let res = await fetch('/get/models')
|
||||
const models = await res.json()
|
||||
|
||||
// let activeModel = models['active']
|
||||
let activeModels = models['active']
|
||||
let modelOptions = models['options']
|
||||
let stableDiffusionOptions = modelOptions['stable-diffusion']
|
||||
let vaeOptions = modelOptions['vae']
|
||||
let activeVae = activeModels['vae']
|
||||
|
||||
stableDiffusionOptions.forEach(modelName => {
|
||||
let modelOption = document.createElement('option')
|
||||
modelOption.value = modelName
|
||||
modelOption.innerText = modelName
|
||||
function createModelOptions(modelField, selectedModel) {
|
||||
return function(modelName) {
|
||||
let modelOption = document.createElement('option')
|
||||
modelOption.value = modelName
|
||||
modelOption.innerText = modelName
|
||||
|
||||
if (modelName === selectedModel) {
|
||||
modelOption.selected = true
|
||||
if (modelName === selectedModel) {
|
||||
modelOption.selected = true
|
||||
}
|
||||
|
||||
modelField.appendChild(modelOption)
|
||||
}
|
||||
}
|
||||
|
||||
stableDiffusionModelField.appendChild(modelOption)
|
||||
})
|
||||
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedModel))
|
||||
vaeOptions.forEach(createModelOptions(vaeModelField, activeVae))
|
||||
|
||||
// TODO: set default for model here too
|
||||
SETTINGS[model_setting_key].default = stableDiffusionOptions[0]
|
||||
|
@ -23,6 +23,7 @@ class Request:
|
||||
use_face_correction: str = None # or "GFPGANv1.3"
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
|
||||
@ -45,6 +46,7 @@ class Request:
|
||||
"use_face_correction": self.use_face_correction,
|
||||
"use_upscale": self.use_upscale,
|
||||
"use_stable_diffusion_model": self.use_stable_diffusion_model,
|
||||
"use_vae_model": self.use_vae_model,
|
||||
"output_format": self.output_format,
|
||||
}
|
||||
|
||||
@ -67,6 +69,7 @@ class Request:
|
||||
use_face_correction: {self.use_face_correction}
|
||||
use_upscale: {self.use_upscale}
|
||||
use_stable_diffusion_model: {self.use_stable_diffusion_model}
|
||||
use_vae_model: {self.use_vae_model}
|
||||
show_only_filtered_image: {self.show_only_filtered_image}
|
||||
output_format: {self.output_format}
|
||||
|
||||
|
@ -87,6 +87,7 @@ def device_init(device_selection=None):
|
||||
thread_data.temp_images = {}
|
||||
|
||||
thread_data.ckpt_file = None
|
||||
thread_data.vae_file = None
|
||||
thread_data.gfpgan_file = None
|
||||
thread_data.real_esrgan_file = None
|
||||
|
||||
@ -184,7 +185,7 @@ def load_model_ckpt():
|
||||
if thread_data.device == 'cpu':
|
||||
thread_data.precision = 'full'
|
||||
|
||||
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
|
||||
print('loading', thread_data.ckpt_file + '.ckpt', 'to', thread_data.device, 'using precision', thread_data.precision)
|
||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
@ -231,6 +232,16 @@ def load_model_ckpt():
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
|
||||
if thread_data.vae_file is not None:
|
||||
if os.path.exists(thread_data.vae_file + '.vae.pt'):
|
||||
print(f"Loading VAE weights from: {thread_data.vae_file}.vae.pt")
|
||||
vae_ckpt = torch.load(thread_data.vae_file + '.vae.pt', map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
else:
|
||||
print(f'Cannot find VAE file: {thread_data.vae_file}.vae.pt')
|
||||
|
||||
modelFS.eval()
|
||||
if thread_data.device != 'cpu':
|
||||
if thread_data.reduced_memory:
|
||||
@ -459,8 +470,9 @@ def do_mk_img(req: Request):
|
||||
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
||||
|
||||
needs_model_reload = False
|
||||
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
|
||||
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||
thread_data.vae_file = req.use_vae_model
|
||||
needs_model_reload = True
|
||||
|
||||
if thread_data.has_valid_gpu:
|
||||
|
@ -73,6 +73,7 @@ class ImageRequest(BaseModel):
|
||||
use_face_correction: str = None # or "GFPGANv1.3"
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
|
||||
@ -170,27 +171,34 @@ render_threads = []
|
||||
current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_model_path = None
|
||||
current_vae_path = None
|
||||
tasks_queue = []
|
||||
task_cache = TaskCache()
|
||||
default_model_to_load = None
|
||||
default_vae_to_load = None
|
||||
weak_thread_data = weakref.WeakKeyDictionary()
|
||||
|
||||
def preload_model(file_path=None):
|
||||
def preload_model(ckpt_file_path=None, vae_file_path=None):
|
||||
global current_state, current_state_error, current_model_path
|
||||
if file_path == None:
|
||||
file_path = default_model_to_load
|
||||
if file_path == current_model_path:
|
||||
if ckpt_file_path == None:
|
||||
ckpt_file_path = default_model_to_load
|
||||
if vae_file_path == None:
|
||||
vae_file_path = default_vae_to_load
|
||||
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
|
||||
return
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
from . import runtime
|
||||
runtime.thread_data.ckpt_file = file_path
|
||||
runtime.thread_data.ckpt_file = ckpt_file_path
|
||||
runtime.thread_data.vae_file = vae_file_path
|
||||
runtime.load_model_ckpt()
|
||||
current_model_path = file_path
|
||||
current_model_path = ckpt_file_path
|
||||
current_vae_path = vae_file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
current_model_path = None
|
||||
current_vae_path = None
|
||||
current_state_error = e
|
||||
current_state = ServerStates.Unavailable
|
||||
print(traceback.format_exc())
|
||||
@ -240,7 +248,7 @@ def thread_get_next_task():
|
||||
manager_lock.release()
|
||||
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error, current_model_path
|
||||
global current_state, current_state_error, current_model_path, current_vae_path
|
||||
from . import runtime
|
||||
try:
|
||||
runtime.device_init(device)
|
||||
@ -289,6 +297,7 @@ def thread_render(device):
|
||||
if current_state == ServerStates.LoadingModel:
|
||||
current_state = ServerStates.Rendering
|
||||
current_model_path = task.request.use_stable_diffusion_model
|
||||
current_vae_path = task.request.use_vae_model
|
||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||
runtime.thread_data.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
@ -411,6 +420,7 @@ def render(req : ImageRequest):
|
||||
r.use_upscale: str = req.use_upscale
|
||||
r.use_face_correction = req.use_face_correction
|
||||
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
||||
r.use_vae_model = req.use_vae_model
|
||||
r.show_only_filtered_image = req.show_only_filtered_image
|
||||
r.output_format = req.output_format
|
||||
|
||||
|
77
ui/server.py
77
ui/server.py
@ -30,6 +30,9 @@ APP_CONFIG_DEFAULT_MODELS = [
|
||||
'custom-model', # Check if user has a custom model, use it first.
|
||||
'sd-v1-4', # Default fallback.
|
||||
]
|
||||
APP_CONFIG_DEFAULT_VAE = [
|
||||
'vae-ft-mse-840000-ema-pruned', # Default fallback.
|
||||
]
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@ -129,37 +132,57 @@ def setConfig(config):
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
def resolve_model_to_use(model_name:str=None):
|
||||
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extension:str, default_models=[]):
|
||||
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
|
||||
if not model_name: # When None try user configured model.
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
model_name = config['model']['stable-diffusion']
|
||||
if 'model' in config and model_type in config['model']:
|
||||
model_name = config['model'][model_type]
|
||||
if model_name:
|
||||
# Check models directory
|
||||
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
if os.path.exists(models_dir_path + '.ckpt'):
|
||||
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
|
||||
if os.path.exists(models_dir_path + model_extension):
|
||||
return models_dir_path
|
||||
if os.path.exists(model_name + '.ckpt'):
|
||||
if os.path.exists(model_name + model_extension):
|
||||
# Direct Path to file
|
||||
model_name = os.path.abspath(model_name)
|
||||
return model_name
|
||||
# Default locations
|
||||
if model_name in APP_CONFIG_DEFAULT_MODELS:
|
||||
if model_name in default_models:
|
||||
default_model_path = os.path.join(SD_DIR, model_name)
|
||||
if os.path.exists(default_model_path + '.ckpt'):
|
||||
if os.path.exists(default_model_path + model_extension):
|
||||
return default_model_path
|
||||
# Can't find requested model, check the default paths.
|
||||
for default_model in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, default_model)
|
||||
if os.path.exists(default_model_path + '.ckpt'):
|
||||
if model_name is not None:
|
||||
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
|
||||
return default_model_path
|
||||
for default_model in default_models:
|
||||
for model_dir in model_dirs:
|
||||
default_model_path = os.path.join(model_dir, default_model)
|
||||
if os.path.exists(default_model_path + model_extension):
|
||||
if model_name is not None:
|
||||
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
||||
return default_model_path
|
||||
raise Exception('No valid models found.')
|
||||
|
||||
def resolve_ckpt_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extension='.ckpt', default_models=APP_CONFIG_DEFAULT_MODELS)
|
||||
|
||||
def resolve_vae_to_use(ckpt_model_path:str=None):
|
||||
if ckpt_model_path is not None:
|
||||
if os.path.exists(ckpt_model_path + '.vae.pt'):
|
||||
return ckpt_model_path
|
||||
|
||||
ckpt_model_name = os.path.basename(ckpt_model_path)
|
||||
model_dirs = [os.path.join(MODELS_DIR, 'stable-diffusion'), SD_DIR]
|
||||
for model_dir in model_dirs:
|
||||
default_model_path = os.path.join(model_dir, ckpt_model_name)
|
||||
if os.path.exists(default_model_path + '.vae.pt'):
|
||||
return default_model_path
|
||||
|
||||
return resolve_model_to_use(model_name=None, model_type='vae', model_dir='stable-diffusion', model_extension='.vae.pt', default_models=APP_CONFIG_DEFAULT_VAE)
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = None
|
||||
render_devices: Union[List[str], List[int], str, int] = None
|
||||
model_vae: str = None
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
@ -180,6 +203,10 @@ async def setAppConfig(req : SetAppConfigRequest):
|
||||
render_devices.append('GPU:' + req.render_devices)
|
||||
if len(render_devices) > 0:
|
||||
config['render_devices'] = render_devices
|
||||
if req.model_vae:
|
||||
if 'model' not in config:
|
||||
config['model'] = {}
|
||||
config['model']['vae'] = req.model_vae
|
||||
try:
|
||||
setConfig(config)
|
||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||
@ -191,28 +218,28 @@ def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
'vae': '',
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
'vae': [],
|
||||
},
|
||||
}
|
||||
|
||||
# custom models
|
||||
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
||||
for file in os.listdir(sd_models_dir):
|
||||
if file.endswith('.ckpt'):
|
||||
model_name = os.path.splitext(file)[0]
|
||||
models['options']['stable-diffusion'].append(model_name)
|
||||
for model_type, model_extension in [('stable-diffusion', '.ckpt'), ('vae', '.vae.pt')]:
|
||||
for file in os.listdir(sd_models_dir):
|
||||
if file.endswith(model_extension):
|
||||
model_name = file[:-len(model_extension)]
|
||||
models['options'][model_type].append(model_name)
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
if os.path.exists(custom_weight_path):
|
||||
models['active']['stable-diffusion'] = 'custom-model'
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
|
||||
models['active']['vae'] = os.path.basename(task_manager.default_vae_to_load)
|
||||
|
||||
return models
|
||||
|
||||
@ -283,7 +310,8 @@ def render(req : task_manager.ImageRequest):
|
||||
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
|
||||
try:
|
||||
save_model_to_config(req.use_stable_diffusion_model)
|
||||
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
|
||||
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
|
||||
req.use_vae_model = resolve_vae_to_use(ckpt_model_path=req.use_stable_diffusion_model)
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
@ -361,7 +389,8 @@ logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
config = getConfig()
|
||||
|
||||
# Start the task_manager
|
||||
task_manager.default_model_to_load = resolve_model_to_use()
|
||||
task_manager.default_model_to_load = resolve_ckpt_to_use()
|
||||
task_manager.default_vae_to_load = resolve_vae_to_use(ckpt_model_path=task_manager.default_model_to_load)
|
||||
if 'render_devices' in config: # Start a new thread for each device.
|
||||
if isinstance(config['render_devices'], str):
|
||||
config['render_devices'] = config['render_devices'].split(',')
|
||||
|
Loading…
Reference in New Issue
Block a user