mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-24 17:24:29 +01:00
Fix handling of embeddings with space in their name (#1402)
* Fix handling of files with space in their name * Handle embeddings in save files * Moved get_embedding_token * Moved get_embedding_token * Update save_utils.py
This commit is contained in:
parent
31edce4a60
commit
cb7ba96dad
@ -11,6 +11,7 @@ from sdkit import Context
|
||||
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
|
||||
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
|
||||
from sdkit.utils import hash_file_quick
|
||||
from sdkit.models.model_loader.embeddings import get_embedding_token
|
||||
|
||||
KNOWN_MODEL_TYPES = [
|
||||
"stable-diffusion",
|
||||
@ -322,9 +323,11 @@ def getModels(scan_for_malicious: bool = True):
|
||||
class MaliciousModelException(Exception):
|
||||
"Raised when picklescan reports a problem with a model"
|
||||
|
||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]):
|
||||
tree = list(default_entries)
|
||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None):
|
||||
nonlocal models_scanned
|
||||
|
||||
tree = list(default_entries)
|
||||
|
||||
for entry in sorted(
|
||||
os.scandir(directory),
|
||||
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
||||
@ -343,7 +346,11 @@ def getModels(scan_for_malicious: bool = True):
|
||||
raise MaliciousModelException(entry.path)
|
||||
if scan_for_malicious:
|
||||
known_models[entry.path] = mtime
|
||||
|
||||
model_id = entry.name[: -len(matching_suffix)]
|
||||
if callable(nameFilter):
|
||||
model_id = nameFilter(model_id)
|
||||
|
||||
model_exists = False
|
||||
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
|
||||
if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m):
|
||||
@ -351,14 +358,15 @@ def getModels(scan_for_malicious: bool = True):
|
||||
break
|
||||
if not model_exists:
|
||||
tree.append(model_id)
|
||||
|
||||
elif entry.is_dir():
|
||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False, nameFilter=nameFilter)
|
||||
|
||||
if len(scan) != 0:
|
||||
tree.append((entry.name, scan))
|
||||
return tree
|
||||
|
||||
def listModels(model_type):
|
||||
def listModels(model_type, nameFilter = None):
|
||||
nonlocal models_scanned
|
||||
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
@ -368,7 +376,7 @@ def getModels(scan_for_malicious: bool = True):
|
||||
|
||||
try:
|
||||
default_tree = models["options"].get(model_type, [])
|
||||
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree)
|
||||
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter)
|
||||
except MaliciousModelException as e:
|
||||
models["scan-error"] = str(e)
|
||||
|
||||
@ -380,7 +388,7 @@ def getModels(scan_for_malicious: bool = True):
|
||||
listModels(model_type="hypernetwork")
|
||||
listModels(model_type="gfpgan")
|
||||
listModels(model_type="lora")
|
||||
listModels(model_type="embeddings")
|
||||
listModels(model_type="embeddings", nameFilter=get_embedding_token)
|
||||
listModels(model_type="controlnet")
|
||||
|
||||
if scan_for_malicious and models_scanned > 0:
|
||||
|
@ -10,6 +10,7 @@ from easydiffusion import app
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
|
||||
from numpy import base_repr
|
||||
from sdkit.utils import save_dicts, save_images
|
||||
from sdkit.models.model_loader.embeddings import get_embedding_token
|
||||
|
||||
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
||||
img_number_regex = re.compile("([0-9]{5,})")
|
||||
@ -228,7 +229,26 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output
|
||||
metadata[key] = req_metadata[key]
|
||||
elif key in task_data_metadata:
|
||||
metadata[key] = task_data_metadata[key]
|
||||
|
||||
|
||||
if key == "use_embeddings_model" and using_diffusers:
|
||||
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
||||
|
||||
def scan_directory(directory_path: str):
|
||||
used_embeddings = []
|
||||
for entry in os.scandir(directory_path):
|
||||
if entry.is_file():
|
||||
# Check if the filename has the right extension
|
||||
if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)):
|
||||
continue
|
||||
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])")
|
||||
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
||||
used_embeddings.append(entry.path)
|
||||
elif entry.is_dir():
|
||||
used_embeddings.extend(scan_directory(entry.path))
|
||||
return used_embeddings
|
||||
|
||||
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
||||
metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None
|
||||
|
||||
# Clean up the metadata
|
||||
if req.init_image is None and "prompt_strength" in metadata:
|
||||
|
Loading…
Reference in New Issue
Block a user