diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index e08328fe..fefcc916 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -43,6 +43,7 @@ def load_default_models(context: Context): # init default model paths for model_type in MODELS_TO_LOAD_ON_START: context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) + set_model_config_path(context, model_type) load_model(context, model_type) def unload_all(context: Context): @@ -101,16 +102,26 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): if set_vram_optimizations(context): # reload SD models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion'] - if 'stable-diffusion' in models_to_reload: - quick_hash = hash_file_quick(models_to_reload['stable-diffusion']) - known_model_info = get_model_info_from_db(quick_hash=quick_hash) - for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req + set_model_config_path(context, model_type) action_fn = unload_model if context.model_paths[model_type] is None else load_model action_fn(context, model_type, scan_model=False) # we've scanned them already +def set_model_config_path(context: Context, model_type: str): + if model_type != 'stable-diffusion': + return + + context.model_configs['stable-diffusion'] = None # reset this, to avoid loading the last config + + # look for a yaml file next to the model, otherwise let sdkit match it to a known model + model_path = context.model_paths['stable-diffusion'] + file_path, _ = os.path.splitext(model_path) + config_path = file_path + '.yaml' + if os.path.exists(config_path): + context.model_configs['stable-diffusion'] = config_path + def resolve_model_paths(task_data: TaskData): task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')