From 4a62d4e76e5194793a86c400f2d92066cd6755f9 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sat, 4 Jan 2025 18:04:43 +0530 Subject: [PATCH] Workaround for when the context doesn't have a model_load_errors field; Not sure why it doesn't have it --- ui/easydiffusion/model_manager.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index e904a315..fb167399 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -86,7 +86,7 @@ def load_default_models(context: Context): scan_model=context.model_paths[model_type] != None and not context.model_paths[model_type].endswith(".safetensors"), ) - if model_type in context.model_load_errors: + if hasattr(context, "model_load_errors") and model_type in context.model_load_errors: del context.model_load_errors[model_type] except Exception as e: log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]") @@ -98,6 +98,8 @@ def load_default_models(context: Context): log.exception(e) del context.model_paths[model_type] + if not hasattr(context, "model_load_errors"): + context.model_load_errors = {} context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks @@ -193,11 +195,13 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models extra_params = models_data.model_params.get(model_type, {}) try: action_fn(context, model_type, scan_model=False, **extra_params) # we've scanned them already - if model_type in context.model_load_errors: + if hasattr(context, "model_load_errors") and model_type in context.model_load_errors: del context.model_load_errors[model_type] except Exception as e: log.exception(e) if action_fn == backend.load_model: + if not hasattr(context, "model_load_errors"): + context.model_load_errors = {} context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks @@ -234,7 +238,7 @@ def resolve_model_paths(models_data: ModelsData): def fail_if_models_did_not_load(context: Context): for model_type in KNOWN_MODEL_TYPES: - if model_type in context.model_load_errors: + if hasattr(context, "model_load_errors") and model_type in context.model_load_errors: e = context.model_load_errors[model_type] raise Exception(f"Could not load the {model_type} model! Reason: " + e)