Fix handling of embeddings with space in their name (#1402)

* Fix handling of files with space in their name

* Handle embeddings in save files

* Moved get_embedding_token

* Moved get_embedding_token

* Update save_utils.py
This commit is contained in:
JeLuF 2023-08-24 13:02:17 +02:00 committed by GitHub
parent 31edce4a60
commit cb7ba96dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 7 deletions

View File

@ -11,6 +11,7 @@ from sdkit import Context
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
from sdkit.utils import hash_file_quick from sdkit.utils import hash_file_quick
from sdkit.models.model_loader.embeddings import get_embedding_token
KNOWN_MODEL_TYPES = [ KNOWN_MODEL_TYPES = [
"stable-diffusion", "stable-diffusion",
@ -322,9 +323,11 @@ def getModels(scan_for_malicious: bool = True):
class MaliciousModelException(Exception): class MaliciousModelException(Exception):
"Raised when picklescan reports a problem with a model" "Raised when picklescan reports a problem with a model"
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]): def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None):
tree = list(default_entries)
nonlocal models_scanned nonlocal models_scanned
tree = list(default_entries)
for entry in sorted( for entry in sorted(
os.scandir(directory), os.scandir(directory),
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
@ -343,7 +346,11 @@ def getModels(scan_for_malicious: bool = True):
raise MaliciousModelException(entry.path) raise MaliciousModelException(entry.path)
if scan_for_malicious: if scan_for_malicious:
known_models[entry.path] = mtime known_models[entry.path] = mtime
model_id = entry.name[: -len(matching_suffix)] model_id = entry.name[: -len(matching_suffix)]
if callable(nameFilter):
model_id = nameFilter(model_id)
model_exists = False model_exists = False
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m): if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m):
@ -351,14 +358,15 @@ def getModels(scan_for_malicious: bool = True):
break break
if not model_exists: if not model_exists:
tree.append(model_id) tree.append(model_id)
elif entry.is_dir(): elif entry.is_dir():
scan = scan_directory(entry.path, suffixes, directoriesFirst=False) scan = scan_directory(entry.path, suffixes, directoriesFirst=False, nameFilter=nameFilter)
if len(scan) != 0: if len(scan) != 0:
tree.append((entry.name, scan)) tree.append((entry.name, scan))
return tree return tree
def listModels(model_type): def listModels(model_type, nameFilter = None):
nonlocal models_scanned nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, []) model_extensions = MODEL_EXTENSIONS.get(model_type, [])
@ -368,7 +376,7 @@ def getModels(scan_for_malicious: bool = True):
try: try:
default_tree = models["options"].get(model_type, []) default_tree = models["options"].get(model_type, [])
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree) models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter)
except MaliciousModelException as e: except MaliciousModelException as e:
models["scan-error"] = str(e) models["scan-error"] = str(e)
@ -380,7 +388,7 @@ def getModels(scan_for_malicious: bool = True):
listModels(model_type="hypernetwork") listModels(model_type="hypernetwork")
listModels(model_type="gfpgan") listModels(model_type="gfpgan")
listModels(model_type="lora") listModels(model_type="lora")
listModels(model_type="embeddings") listModels(model_type="embeddings", nameFilter=get_embedding_token)
listModels(model_type="controlnet") listModels(model_type="controlnet")
if scan_for_malicious and models_scanned > 0: if scan_for_malicious and models_scanned > 0:

View File

@ -10,6 +10,7 @@ from easydiffusion import app
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
from numpy import base_repr from numpy import base_repr
from sdkit.utils import save_dicts, save_images from sdkit.utils import save_dicts, save_images
from sdkit.models.model_loader.embeddings import get_embedding_token
filename_regex = re.compile("[^a-zA-Z0-9._-]") filename_regex = re.compile("[^a-zA-Z0-9._-]")
img_number_regex = re.compile("([0-9]{5,})") img_number_regex = re.compile("([0-9]{5,})")
@ -228,7 +229,26 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output
metadata[key] = req_metadata[key] metadata[key] = req_metadata[key]
elif key in task_data_metadata: elif key in task_data_metadata:
metadata[key] = task_data_metadata[key] metadata[key] = task_data_metadata[key]
if key == "use_embeddings_model" and using_diffusers:
embeddings_extensions = {".pt", ".bin", ".safetensors"}
def scan_directory(directory_path: str):
used_embeddings = []
for entry in os.scandir(directory_path):
if entry.is_file():
# Check if the filename has the right extension
if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)):
continue
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])")
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
used_embeddings.append(entry.path)
elif entry.is_dir():
used_embeddings.extend(scan_directory(entry.path))
return used_embeddings
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None
# Clean up the metadata # Clean up the metadata
if req.init_image is None and "prompt_strength" in metadata: if req.init_image is None and "prompt_strength" in metadata: