mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-01 06:44:56 +02:00
Pick the right embedding even if it has an underscore
This commit is contained in:
parent
cb7ba96dad
commit
3f278cf2ad
@ -100,7 +100,16 @@ def unload_all(context: Context):
|
|||||||
|
|
||||||
def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True):
|
def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True):
|
||||||
model_names = model_name if isinstance(model_name, list) else [model_name]
|
model_names = model_name if isinstance(model_name, list) else [model_name]
|
||||||
model_paths = [resolve_model_to_use_single(m, model_type, fail_if_not_found) for m in model_names]
|
model_paths = []
|
||||||
|
for m in model_names:
|
||||||
|
if model_type == "embeddings":
|
||||||
|
try:
|
||||||
|
resolve_model_to_use_single(m, model_type, fail_if_not_found)
|
||||||
|
except FileNotFoundError: # try with spaces
|
||||||
|
m = m.replace("_", " ")
|
||||||
|
|
||||||
|
path = resolve_model_to_use_single(m, model_type, fail_if_not_found)
|
||||||
|
model_paths.append(path)
|
||||||
|
|
||||||
return model_paths[0] if len(model_paths) == 1 else model_paths
|
return model_paths[0] if len(model_paths) == 1 else model_paths
|
||||||
|
|
||||||
@ -139,7 +148,9 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
|||||||
return default_model_path
|
return default_model_path
|
||||||
|
|
||||||
if model_name and fail_if_not_found:
|
if model_name and fail_if_not_found:
|
||||||
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
|
raise FileNotFoundError(
|
||||||
|
f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
|
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
|
||||||
@ -323,7 +334,7 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
class MaliciousModelException(Exception):
|
class MaliciousModelException(Exception):
|
||||||
"Raised when picklescan reports a problem with a model"
|
"Raised when picklescan reports a problem with a model"
|
||||||
|
|
||||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None):
|
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter=None):
|
||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
|
|
||||||
tree = list(default_entries)
|
tree = list(default_entries)
|
||||||
@ -366,7 +377,7 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
tree.append((entry.name, scan))
|
tree.append((entry.name, scan))
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def listModels(model_type, nameFilter = None):
|
def listModels(model_type, nameFilter=None):
|
||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
|
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
@ -376,7 +387,9 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
default_tree = models["options"].get(model_type, [])
|
default_tree = models["options"].get(model_type, [])
|
||||||
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter)
|
models["options"][model_type] = scan_directory(
|
||||||
|
models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
|
||||||
|
)
|
||||||
except MaliciousModelException as e:
|
except MaliciousModelException as e:
|
||||||
models["scan-error"] = str(e)
|
models["scan-error"] = str(e)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user