Simplify the code for VAE loading, and make it faster to load VAEs (because we don't reload the entire SD model each time a VAE changes); Record the error and end the thread if the SD model fails to load during startup

This commit is contained in:
cmdr2 2022-12-13 15:45:44 +05:30
parent a483bd0800
commit 6cd0b530c5
2 changed files with 17 additions and 18 deletions

View File

@ -42,6 +42,8 @@ def load_default_models(context: Context):
# load mandatory models
model_loader.load_model(context, 'stable-diffusion')
model_loader.load_model(context, 'vae')
model_loader.load_model(context, 'hypernetwork')
def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES:
@ -93,17 +95,13 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
def reload_models_if_necessary(context: Context, task_data: TaskData):
model_paths_in_req = (
('stable-diffusion', task_data.use_stable_diffusion_model),
('vae', task_data.use_vae_model),
('hypernetwork', task_data.use_hypernetwork_model),
('gfpgan', task_data.use_face_correction),
('realesrgan', task_data.use_upscale),
)
if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model:
context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
context.model_paths['vae'] = task_data.use_vae_model
model_loader.load_model(context, 'stable-diffusion')
for model_type, model_path_in_req in model_paths_in_req:
if context.model_paths.get(model_type) != model_path_in_req:
context.model_paths[model_type] = model_path_in_req

View File

@ -222,12 +222,6 @@ def thread_render(device):
from sd_internal import renderer, model_manager
try:
renderer.init(device)
except Exception as e:
log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = {
'error': e
}
return
weak_thread_data[threading.current_thread()] = {
'device': renderer.context.device,
@ -239,6 +233,13 @@ def thread_render(device):
model_manager.load_default_models(renderer.context)
current_state = ServerStates.Online
except Exception as e:
log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = {
'error': e,
'alive': False
}
return
while True:
session_cache.clean()