diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 99c1c84c..aef25e46 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -100,7 +100,16 @@ def unload_all(context: Context): def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True): model_names = model_name if isinstance(model_name, list) else [model_name] - model_paths = [resolve_model_to_use_single(m, model_type, fail_if_not_found) for m in model_names] + model_paths = [] + for m in model_names: + if model_type == "embeddings": + try: + resolve_model_to_use_single(m, model_type, fail_if_not_found) + except FileNotFoundError: # try with spaces + m = m.replace("_", " ") + + path = resolve_model_to_use_single(m, model_type, fail_if_not_found) + model_paths.append(path) return model_paths[0] if len(model_paths) == 1 else model_paths @@ -139,7 +148,9 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, return default_model_path if model_name and fail_if_not_found: - raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?") + raise FileNotFoundError( + f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?" + ) def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []): @@ -323,7 +334,7 @@ def getModels(scan_for_malicious: bool = True): class MaliciousModelException(Exception): "Raised when picklescan reports a problem with a model" - def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None): + def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter=None): nonlocal models_scanned tree = list(default_entries) @@ -366,7 +377,7 @@ def getModels(scan_for_malicious: bool = True): tree.append((entry.name, scan)) return tree - def listModels(model_type, nameFilter = None): + def listModels(model_type, nameFilter=None): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) @@ -376,7 +387,9 @@ def getModels(scan_for_malicious: bool = True): try: default_tree = models["options"].get(model_type, []) - models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter) + models["options"][model_type] = scan_directory( + models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter + ) except MaliciousModelException as e: models["scan-error"] = str(e)