Pick the right embedding even if it has an underscore

This commit is contained in:
cmdr2 2023-08-24 17:36:49 +05:30
parent cb7ba96dad
commit 3f278cf2ad

View File

@ -100,7 +100,16 @@ def unload_all(context: Context):
def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True):
model_names = model_name if isinstance(model_name, list) else [model_name]
model_paths = [resolve_model_to_use_single(m, model_type, fail_if_not_found) for m in model_names]
model_paths = []
for m in model_names:
if model_type == "embeddings":
try:
resolve_model_to_use_single(m, model_type, fail_if_not_found)
except FileNotFoundError: # try with spaces
m = m.replace("_", " ")
path = resolve_model_to_use_single(m, model_type, fail_if_not_found)
model_paths.append(path)
return model_paths[0] if len(model_paths) == 1 else model_paths
@ -139,7 +148,9 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
return default_model_path
if model_name and fail_if_not_found:
raise Exception(f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?")
raise FileNotFoundError(
f"Could not find the desired model {model_name}! Is it present in the {model_dir} folder?"
)
def reload_models_if_necessary(context: Context, models_data: ModelsData, models_to_force_reload: list = []):
@ -323,7 +334,7 @@ def getModels(scan_for_malicious: bool = True):
class MaliciousModelException(Exception):
"Raised when picklescan reports a problem with a model"
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None):
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter=None):
nonlocal models_scanned
tree = list(default_entries)
@ -366,7 +377,7 @@ def getModels(scan_for_malicious: bool = True):
tree.append((entry.name, scan))
return tree
def listModels(model_type, nameFilter = None):
def listModels(model_type, nameFilter=None):
nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
@ -376,7 +387,9 @@ def getModels(scan_for_malicious: bool = True):
try:
default_tree = models["options"].get(model_type, [])
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter)
models["options"][model_type] = scan_directory(
models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
)
except MaliciousModelException as e:
models["scan-error"] = str(e)