Support custom VAE files. vae-ft-mse-840000-ema-pruned is used as the default VAE; it can be overridden per model by placing a .vae.pt file with the same name as the ckpt model file inside models/stable-diffusion. The default VAE model to use can also be set from the UI (System Settings).

This commit is contained in:
cmdr2 2022-10-28 20:06:44 +05:30
parent 79a7cd2938
commit a8c16e39b8
8 changed files with 196 additions and 56 deletions
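
The override behaviour described in the commit message can be summarised as follows. This is a minimal sketch for illustration only, assuming the standard models/stable-diffusion layout; the helper name pick_vae_for is hypothetical and is not part of this commit.

import os

MODELS_DIR = 'models/stable-diffusion'
DEFAULT_VAE = 'vae-ft-mse-840000-ema-pruned'

def pick_vae_for(ckpt_name):
    # hypothetical helper mirroring the precedence from the commit message:
    # 1. a .vae.pt file with the same base name as the ckpt model wins
    per_model_vae = os.path.join(MODELS_DIR, ckpt_name + '.vae.pt')
    if os.path.exists(per_model_vae):
        return per_model_vae
    # 2. otherwise the default VAE (configurable in UI / System Settings) is used
    return os.path.join(MODELS_DIR, DEFAULT_VAE + '.vae.pt')

# e.g. models/stable-diffusion/my-model.ckpt + my-model.vae.pt
# makes the custom VAE take precedence over the default one
print(pick_vae_for('my-model'))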

View File

@ -321,6 +321,36 @@ echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
@if exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
for %%I in ("..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt") do if "%%~zI" EQU "334695179" (
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
) else (
echo. & echo "The default VAE (sd-vae-ft-mse-original) file present at models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt is invalid. It is only %%~zI bytes in size. Re-downloading.." & echo.
del "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt"
)
)
@if not exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
@echo. & echo "Downloading data files (weights) for the default VAE (sd-vae-ft-mse-original).." & echo.
@call curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt
@if exist "..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt" (
for %%I in ("..\models\stable-diffusion\vae-ft-mse-840000-ema-pruned.vae.pt") do if "%%~zI" NEQ "334695179" (
echo. & echo "Error: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: %%~zI" & echo.
echo. & echo "Error downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
) else (
@echo. & echo "Error downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (
@echo sd_weights_downloaded >> ..\scripts\install_status.txt

View File

@ -300,6 +300,38 @@ if [ ! -f "RealESRGAN_x4plus_anime_6B.pth" ]; then
fi
if [ -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
model_size=`find ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt -printf "%s"`
if [ "$model_size" -eq "334695179" ]; then
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
else
printf "\n\nThe model file present at models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt is invalid. It is only $model_size bytes in size. Re-downloading.."
rm ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt
fi
fi
if [ ! -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
echo "Downloading data files (weights) for the default VAE (sd-vae-ft-mse-original).."
curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt
if [ -f "../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt" ]; then
model_size=`find ../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt -printf "%s"`
if [ ! "$model_size" -eq "334695179" ]; then
printf "\n\nError: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: $model_size\n\n"
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
read -p "Press any key to continue"
exit
fi
else
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
read -p "Press any key to continue"
exit
fi
fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt
echo sd_install_complete >> ../scripts/install_status.txt
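
Both installer scripts follow the same download-and-verify flow. A rough Python equivalent, for illustration only: the project ships the batch and shell versions above, download_default_vae is a hypothetical name, and the URL and expected byte size are taken from the scripts.

import os
import urllib.request

VAE_URL = 'https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt'
VAE_DEST = '../models/stable-diffusion/vae-ft-mse-840000-ema-pruned.vae.pt'
EXPECTED_SIZE = 334695179  # byte size both installer scripts check against

def download_default_vae():
    # discard a partial or corrupt file so it gets re-downloaded
    if os.path.exists(VAE_DEST) and os.path.getsize(VAE_DEST) != EXPECTED_SIZE:
        os.remove(VAE_DEST)
    if not os.path.exists(VAE_DEST):
        urllib.request.urlretrieve(VAE_URL, VAE_DEST)
    if os.path.getsize(VAE_DEST) != EXPECTED_SIZE:
        raise RuntimeError('Downloaded default VAE file is invalid; please re-run the installer')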

View File

@ -18,7 +18,7 @@
<div id="container">
<div id="top-nav">
<div id="logo">
<h1>Stable Diffusion UI <small>v2.3.7 <span id="updateBranchLabel"></span></small></h1>
<h1>Stable Diffusion UI <small>v2.3.8 <span id="updateBranchLabel"></span></small></h1>
</div>
<ul id="top-nav-items">
<li class="dropdown">
@ -38,6 +38,12 @@
<br/>
<li><label for="theme">Theme: </label><select id="theme" name="theme"><option value="theme-default">Default</option></select></li>
<li><input id="save_to_disk" name="save_to_disk" type="checkbox"> <label for="save_to_disk">Automatically save to <input id="diskPath" name="diskPath" size="40" disabled></label></li>
<li>
<label for="default_vae_model">Default VAE:</label></td><td>
<select id="default_vae_model" name="default_vae_model">
<!-- <option value="vae-ft-mse-840000-ema-pruned" selected>vae-ft-mse-840000-ema-pruned</option> -->
</select>
</li>
<li><input id="sound_toggle" name="sound_toggle" type="checkbox" checked> <label for="sound_toggle">Play sound on task completion</label></li>
<li><input id="turbo" name="turbo" type="checkbox" checked> <label for="turbo">Turbo mode <small>(generates images faster, but uses an additional 1 GB of GPU memory)</small></label></li>
<li><input id="use_cpu" name="use_cpu" type="checkbox"> <label for="use_cpu">Use CPU instead of GPU <small>(warning: this will be *very* slow)</small></label></li>
@ -273,7 +279,7 @@
<script src="media/js/inpainting-editor.js?v=1"></script>
<script src="media/js/image-modifiers.js?v=3"></script>
<script src="media/js/auto-save.js?v=2"></script>
<script src="media/js/main.js?v=5"></script>
<script src="media/js/main.js?v=6"></script>
<script src="media/js/themes.js?v=2"></script>
<script>
async function init() {

View File

@ -39,6 +39,7 @@ let useFaceCorrectionField = document.querySelector("#use_face_correction")
let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model")
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
let vaeModelField = document.querySelector('#default_vae_model')
let outputFormatField = document.querySelector('#output_format')
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel")
@ -1104,15 +1105,13 @@ promptStrengthSlider.addEventListener('input', updatePromptStrength)
promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength()
useBetaChannelField.addEventListener('click', async function(e) {
if (!isServerAvailable()) {
// logError('The server is still starting up..')
alert('The server is still starting up..')
e.preventDefault()
return false
}
let updateBranch = (this.checked ? 'beta' : 'main')
async function changeAppConfig(configDelta) {
// if (!isServerAvailable()) {
// // logError('The server is still starting up..')
// alert('The server is still starting up..')
// e.preventDefault()
// return false
// }
try {
let res = await fetch('/app_config', {
@ -1120,9 +1119,7 @@ useBetaChannelField.addEventListener('click', async function(e) {
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
'update_branch': updateBranch
})
body: JSON.stringify(configDelta)
})
res = await res.json()
@ -1130,6 +1127,20 @@ useBetaChannelField.addEventListener('click', async function(e) {
} catch (e) {
console.log('set config status error', e)
}
}
useBetaChannelField.addEventListener('click', async function(e) {
let updateBranch = (this.checked ? 'beta' : 'main')
await changeAppConfig({
'update_branch': updateBranch
})
})
vaeModelField.addEventListener('change', async function() {
await changeAppConfig({
'model_vae': this.value
})
})
async function getAppConfig() {
@ -1155,21 +1166,28 @@ async function getModels() {
let res = await fetch('/get/models')
const models = await res.json()
// let activeModel = models['active']
let activeModels = models['active']
let modelOptions = models['options']
let stableDiffusionOptions = modelOptions['stable-diffusion']
let vaeOptions = modelOptions['vae']
let activeVae = activeModels['vae']
stableDiffusionOptions.forEach(modelName => {
let modelOption = document.createElement('option')
modelOption.value = modelName
modelOption.innerText = modelName
function createModelOptions(modelField, selectedModel) {
return function(modelName) {
let modelOption = document.createElement('option')
modelOption.value = modelName
modelOption.innerText = modelName
if (modelName === selectedModel) {
modelOption.selected = true
}
modelField.appendChild(modelOption)
}
}
stableDiffusionModelField.appendChild(modelOption)
})
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedModel))
vaeOptions.forEach(createModelOptions(vaeModelField, activeVae))
// TODO: set default for model here too
SETTINGS[model_setting_key].default = stableDiffusionOptions[0]
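
The dropdown above is driven by two endpoints that also appear in the server changes below: /get/models (lists the discovered .ckpt and .vae.pt files plus the active ones) and /app_config (persists the chosen default VAE as model_vae). A sketch of calling them directly, assuming the UI server is reachable at its usual local address:

import requests

BASE_URL = 'http://localhost:9000'  # assumed local address of the UI server

# list discovered models; 'options' feeds the dropdowns, 'active' marks the current defaults
models = requests.get(f'{BASE_URL}/get/models').json()
print(models['options']['vae'])   # e.g. ['vae-ft-mse-840000-ema-pruned', ...]
print(models['active']['vae'])    # VAE currently configured as the default

# persist a different default VAE, the same call the System Settings dropdown makes
requests.post(f'{BASE_URL}/app_config', json={'model_vae': 'vae-ft-mse-840000-ema-pruned'})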

View File

@ -23,6 +23,7 @@ class Request:
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
@ -45,6 +46,7 @@ class Request:
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"use_vae_model": self.use_vae_model,
"output_format": self.output_format,
}
@ -67,6 +69,7 @@ class Request:
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
use_vae_model: {self.use_vae_model}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}

View File

@ -87,6 +87,7 @@ def device_init(device_selection=None):
thread_data.temp_images = {}
thread_data.ckpt_file = None
thread_data.vae_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
@ -184,7 +185,7 @@ def load_model_ckpt():
if thread_data.device == 'cpu':
thread_data.precision = 'full'
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
print('loading', thread_data.ckpt_file + '.ckpt', 'to', thread_data.device, 'using precision', thread_data.precision)
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
@ -231,6 +232,16 @@ def load_model_ckpt():
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
if thread_data.vae_file is not None:
if os.path.exists(thread_data.vae_file + '.vae.pt'):
print(f"Loading VAE weights from: {thread_data.vae_file}.vae.pt")
vae_ckpt = torch.load(thread_data.vae_file + '.vae.pt', map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f'Cannot find VAE file: {thread_data.vae_file}.vae.pt')
modelFS.eval()
if thread_data.device != 'cpu':
if thread_data.reduced_memory:
@ -459,8 +470,9 @@ def do_mk_img(req: Request):
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.has_valid_gpu:
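
For reference, the VAE-merging step added to load_model_ckpt above can be read as a standalone routine. This is a sketch under the assumption that model.first_stage_model is the autoencoder module of a loaded Stable Diffusion checkpoint; keys beginning with "loss" typically hold the VAE's training-time loss modules and are not needed for inference, which is why they are filtered out.

import torch

def apply_custom_vae(model, vae_path):
    # load the .vae.pt checkpoint on the CPU and keep only inference weights
    vae_ckpt = torch.load(vae_path, map_location='cpu')
    vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if not k.startswith('loss')}
    # strict=False: the file only covers the first-stage autoencoder weights
    model.first_stage_model.load_state_dict(vae_dict, strict=False)
    model.eval()
    return model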

View File

@ -73,6 +73,7 @@ class ImageRequest(BaseModel):
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
@ -170,27 +171,34 @@ render_threads = []
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
current_vae_path = None
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
default_vae_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
def preload_model(file_path=None):
def preload_model(ckpt_file_path=None, vae_file_path=None):
global current_state, current_state_error, current_model_path
if file_path == None:
file_path = default_model_to_load
if file_path == current_model_path:
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.thread_data.ckpt_file = file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
current_model_path = file_path
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_vae_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
@ -240,7 +248,7 @@ def thread_get_next_task():
manager_lock.release()
def thread_render(device):
global current_state, current_state_error, current_model_path
global current_state, current_state_error, current_model_path, current_vae_path
from . import runtime
try:
runtime.device_init(device)
@ -289,6 +297,7 @@ def thread_render(device):
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
@ -411,6 +420,7 @@ def render(req : ImageRequest):
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.use_stable_diffusion_model = req.use_stable_diffusion_model
r.use_vae_model = req.use_vae_model
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format

View File

@ -30,6 +30,9 @@ APP_CONFIG_DEFAULT_MODELS = [
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
APP_CONFIG_DEFAULT_VAE = [
'vae-ft-mse-840000-ema-pruned', # Default fallback.
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
@ -129,37 +132,57 @@ def setConfig(config):
except Exception as e:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str=None):
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extension:str, default_models=[]):
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
if os.path.exists(models_dir_path + '.ckpt'):
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + '.ckpt'):
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in APP_CONFIG_DEFAULT_MODELS:
if model_name in default_models:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + '.ckpt'):
if os.path.exists(default_model_path + model_extension):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, default_model)
if os.path.exists(default_model_path + '.ckpt'):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
return default_model_path
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extension='.ckpt', default_models=APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(ckpt_model_path:str=None):
if ckpt_model_path is not None:
if os.path.exists(ckpt_model_path + '.vae.pt'):
return ckpt_model_path
ckpt_model_name = os.path.basename(ckpt_model_path)
model_dirs = [os.path.join(MODELS_DIR, 'stable-diffusion'), SD_DIR]
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, ckpt_model_name)
if os.path.exists(default_model_path + '.vae.pt'):
return default_model_path
return resolve_model_to_use(model_name=None, model_type='vae', model_dir='stable-diffusion', model_extension='.vae.pt', default_models=APP_CONFIG_DEFAULT_VAE)
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
@ -180,6 +203,10 @@ async def setAppConfig(req : SetAppConfigRequest):
render_devices.append('GPU:' + req.render_devices)
if len(render_devices) > 0:
config['render_devices'] = render_devices
if req.model_vae:
if 'model' not in config:
config['model'] = {}
config['model']['vae'] = req.model_vae
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
@ -191,28 +218,28 @@ def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
for model_type, model_extension in [('stable-diffusion', '.ckpt'), ('vae', '.vae.pt')]:
for file in os.listdir(sd_models_dir):
if file.endswith(model_extension):
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
models['active']['vae'] = os.path.basename(task_manager.default_vae_to_load)
return models
@ -283,7 +310,8 @@ def render(req : task_manager.ImageRequest):
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
try:
save_model_to_config(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(ckpt_model_path=req.use_stable_diffusion_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
@ -361,7 +389,8 @@ logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
config = getConfig()
# Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use(ckpt_model_path=task_manager.default_model_to_load)
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')
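
Taken together, the new helpers in this file give the following lookup order for a render request. A hypothetical session; the paths are examples and the return values assume the files exist as named:

# resolve_ckpt_to_use() locates the .ckpt as before
ckpt_path = resolve_ckpt_to_use('my-custom-model')        # e.g. models/stable-diffusion/my-custom-model
# resolve_vae_to_use() prefers a .vae.pt with the same base name as the ckpt,
# and only then falls back to the configured/default vae-ft-mse-840000-ema-pruned
vae_path = resolve_vae_to_use(ckpt_model_path=ckpt_path)  # the ckpt's base path if my-custom-model.vae.pt
                                                          # exists next to it, else the default VAE path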