Don't crash if a VAE file fails to load

This commit is contained in:
cmdr2 2022-11-18 13:10:56 +05:30
parent 6756fb4fe7
commit 025d4df774
3 changed files with 20 additions and 10 deletions

View File

@ -20,6 +20,7 @@
- A `What's New?` tab in the UI
### Detailed changelog
* 2.4.7 - 18 Nov 2022 - Don't crash if a VAE file fails to load
* 2.4.7 - 17 Nov 2022 - Fix a bug where Face Correction (GFPGAN) would fail on cuda:N (i.e. GPUs other than cuda:0), as well as fail on CPU if the system had an incompatible GPU.
* 2.4.6 - 16 Nov 2022 - Fix a regression in VRAM usage during startup, which caused 'Out of Memory' errors when starting on GPUs with 4gb (or less) VRAM
* 2.4.5 - 16 Nov 2022 - Add checkbox for "Open browser on startup".

View File

@ -140,15 +140,24 @@ def load_model_ckpt():
_, _ = modelFS.load_state_dict(sd, strict=False)
if thread_data.vae_file is not None:
for model_extension in ['.ckpt', '.vae.pt']:
if os.path.exists(thread_data.vae_file + model_extension):
print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
break
else:
print(f'Cannot find VAE file: {thread_data.vae_file}{model_extension}')
try:
loaded = False
for model_extension in ['.ckpt', '.vae.pt']:
if os.path.exists(thread_data.vae_file + model_extension):
print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
loaded = True
break
if not loaded:
print(f'Cannot find VAE: {thread_data.vae_file}')
thread_data.vae_file = None
except:
print(traceback.format_exc())
print(f'Could not load VAE: {thread_data.vae_file}')
thread_data.vae_file = None
modelFS.eval()
# if thread_data.device != 'cpu':

View File

@ -436,7 +436,7 @@ def stop_render_thread(device):
try:
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
except:
print(traceback.format_exec())
print(traceback.format_exc())
return False
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)