diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index d88694e2..99c1c84c 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -11,6 +11,7 @@ from sdkit import Context from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.utils import hash_file_quick +from sdkit.models.model_loader.embeddings import get_embedding_token KNOWN_MODEL_TYPES = [ "stable-diffusion", @@ -322,9 +323,11 @@ def getModels(scan_for_malicious: bool = True): class MaliciousModelException(Exception): "Raised when picklescan reports a problem with a model" - def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]): - tree = list(default_entries) + def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None): nonlocal models_scanned + + tree = list(default_entries) + for entry in sorted( os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), @@ -343,7 +346,11 @@ def getModels(scan_for_malicious: bool = True): raise MaliciousModelException(entry.path) if scan_for_malicious: known_models[entry.path] = mtime + model_id = entry.name[: -len(matching_suffix)] + if callable(nameFilter): + model_id = nameFilter(model_id) + model_exists = False for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m): @@ -351,14 +358,15 @@ def getModels(scan_for_malicious: bool = True): break if not model_exists: tree.append(model_id) + elif entry.is_dir(): - scan = scan_directory(entry.path, suffixes, directoriesFirst=False) + scan = scan_directory(entry.path, suffixes, directoriesFirst=False, nameFilter=nameFilter) if len(scan) != 0: tree.append((entry.name, scan)) return tree - def listModels(model_type): + def listModels(model_type, nameFilter = None): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) @@ -368,7 +376,7 @@ def getModels(scan_for_malicious: bool = True): try: default_tree = models["options"].get(model_type, []) - models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree) + models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter) except MaliciousModelException as e: models["scan-error"] = str(e) @@ -380,7 +388,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="hypernetwork") listModels(model_type="gfpgan") listModels(model_type="lora") - listModels(model_type="embeddings") + listModels(model_type="embeddings", nameFilter=get_embedding_token) listModels(model_type="controlnet") if scan_for_malicious and models_scanned > 0: diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 465e2b8e..3bd5efba 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -10,6 +10,7 @@ from easydiffusion import app from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData from numpy import base_repr from sdkit.utils import save_dicts, save_images +from sdkit.models.model_loader.embeddings import get_embedding_token filename_regex = re.compile("[^a-zA-Z0-9._-]") img_number_regex = re.compile("([0-9]{5,})") @@ -228,7 +229,26 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output metadata[key] = req_metadata[key] elif key in task_data_metadata: metadata[key] = task_data_metadata[key] - + + if key == "use_embeddings_model" and using_diffusers: + embeddings_extensions = {".pt", ".bin", ".safetensors"} + + def scan_directory(directory_path: str): + used_embeddings = [] + for entry in os.scandir(directory_path): + if entry.is_file(): + # Check if the filename has the right extension + if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)): + continue + embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])") + if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt): + used_embeddings.append(entry.path) + elif entry.is_dir(): + used_embeddings.extend(scan_directory(entry.path)) + return used_embeddings + + used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings")) + metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None # Clean up the metadata if req.init_image is None and "prompt_strength" in metadata: