From 6cd0b530c5a791ff33968ae15b66b55fd1415dee Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 13 Dec 2022 15:45:44 +0530 Subject: [PATCH] Simplify the code for VAE loading, and make it faster to load VAEs (because we don't reload the entire SD model each time a VAE changes); Record the error and end the thread if the SD model fails to load during startup --- ui/sd_internal/model_manager.py | 10 ++++------ ui/sd_internal/task_manager.py | 25 +++++++++++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 18418624..b6c4c92d 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -42,6 +42,8 @@ def load_default_models(context: Context): # load mandatory models model_loader.load_model(context, 'stable-diffusion') + model_loader.load_model(context, 'vae') + model_loader.load_model(context, 'hypernetwork') def unload_all(context: Context): for model_type in KNOWN_MODEL_TYPES: @@ -93,17 +95,13 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): def reload_models_if_necessary(context: Context, task_data: TaskData): model_paths_in_req = ( + ('stable-diffusion', task_data.use_stable_diffusion_model), + ('vae', task_data.use_vae_model), ('hypernetwork', task_data.use_hypernetwork_model), ('gfpgan', task_data.use_face_correction), ('realesrgan', task_data.use_upscale), ) - if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model: - context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model - context.model_paths['vae'] = task_data.use_vae_model - - model_loader.load_model(context, 'stable-diffusion') - for model_type, model_path_in_req in model_paths_in_req: if context.model_paths.get(model_type) != model_path_in_req: context.model_paths[model_type] = model_path_in_req diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 0db48d9d..3b8f6082 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -222,24 +222,25 @@ def thread_render(device): from sd_internal import renderer, model_manager try: renderer.init(device) + + weak_thread_data[threading.current_thread()] = { + 'device': renderer.context.device, + 'device_name': renderer.context.device_name, + 'alive': True + } + + current_state = ServerStates.LoadingModel + model_manager.load_default_models(renderer.context) + + current_state = ServerStates.Online except Exception as e: log.error(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { - 'error': e + 'error': e, + 'alive': False } return - weak_thread_data[threading.current_thread()] = { - 'device': renderer.context.device, - 'device_name': renderer.context.device_name, - 'alive': True - } - - current_state = ServerStates.LoadingModel - model_manager.load_default_models(renderer.context) - - current_state = ServerStates.Online - while True: session_cache.clean() task_cache.clean()