mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-19 00:07:48 +02:00
Pick the right embedding even if it has an underscore
This commit is contained in:
parent
cb7ba96dad
commit
3f278cf2ad
@ -100,7 +100,16 @@ def unload_all(context: Context):
|
|||||||
|
|
||||||
def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True):
|
def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True):
|
||||||
model_names = model_name if isinstance(model_name, list) else [model_name]
|
model_names = model_name if isinstance(model_name, list) else [model_name]
|
||||||
model_paths = [resolve_model_to_use_single(m, model_type, fail_if_not_found) for m in model_names]
|
model_paths = []
|
||||||
|
for m in model_names:
|
||||||
|
if model_type == "embeddings":
|
||||||
|
try:
|
||||||
|
resolve_model_to_use_single(m, model_type, fail_if_not_found)
|
||||||
|
except FileNotFoundError: # try with spaces
|
||||||
|
m = m.replace("_", " ")
|
||||||
|
|
||||||
|
path = resolve_model_to_use_single(m, model_type, fail_if_not_found)
|
||||||
|
model_paths.append(path)
|
||||||
|
|
||||||
return model_paths[0] if len(model_paths) == 1 else model_paths
|
return model_paths[0] if len(model_paths) == 1 else model_paths
|
||||||
|
|
||||||
@ -139,7 +148,9 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
|||||||
return default_model_path
|
return default_model_path
|
||||||
|
|
||||||
if model_name and fail_if_not_found:
|
if model_name and fail_if_not_found:
|
||||||
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
|
raise FileNotFoundError(
|
||||||
|
f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
|
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
|
||||||
@ -376,7 +387,9 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
default_tree = models["options"].get(model_type, [])
|
default_tree = models["options"].get(model_type, [])
|
||||||
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter)
|
models["options"][model_type] = scan_directory(
|
||||||
|
models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
|
||||||
|
)
|
||||||
except MaliciousModelException as e:
|
except MaliciousModelException as e:
|
||||||
models["scan-error"] = str(e)
|
models["scan-error"] = str(e)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user