forked from extern/easydiffusion
Fix handling of embeddings with space in their name (#1402)
* Fix handling of files with space in their name * Handle embeddings in save files * Moved get_embedding_token * Moved get_embedding_token * Update save_utils.py
This commit is contained in:
parent
31edce4a60
commit
cb7ba96dad
@ -11,6 +11,7 @@ from sdkit import Context
|
|||||||
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
|
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
|
||||||
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
|
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
|
||||||
from sdkit.utils import hash_file_quick
|
from sdkit.utils import hash_file_quick
|
||||||
|
from sdkit.models.model_loader.embeddings import get_embedding_token
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = [
|
KNOWN_MODEL_TYPES = [
|
||||||
"stable-diffusion",
|
"stable-diffusion",
|
||||||
@ -322,9 +323,11 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
class MaliciousModelException(Exception):
|
class MaliciousModelException(Exception):
|
||||||
"Raised when picklescan reports a problem with a model"
|
"Raised when picklescan reports a problem with a model"
|
||||||
|
|
||||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]):
|
def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[], nameFilter = None):
|
||||||
tree = list(default_entries)
|
|
||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
|
|
||||||
|
tree = list(default_entries)
|
||||||
|
|
||||||
for entry in sorted(
|
for entry in sorted(
|
||||||
os.scandir(directory),
|
os.scandir(directory),
|
||||||
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
||||||
@ -343,7 +346,11 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
raise MaliciousModelException(entry.path)
|
raise MaliciousModelException(entry.path)
|
||||||
if scan_for_malicious:
|
if scan_for_malicious:
|
||||||
known_models[entry.path] = mtime
|
known_models[entry.path] = mtime
|
||||||
|
|
||||||
model_id = entry.name[: -len(matching_suffix)]
|
model_id = entry.name[: -len(matching_suffix)]
|
||||||
|
if callable(nameFilter):
|
||||||
|
model_id = nameFilter(model_id)
|
||||||
|
|
||||||
model_exists = False
|
model_exists = False
|
||||||
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
|
for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models
|
||||||
if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m):
|
if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m):
|
||||||
@ -351,14 +358,15 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
break
|
break
|
||||||
if not model_exists:
|
if not model_exists:
|
||||||
tree.append(model_id)
|
tree.append(model_id)
|
||||||
|
|
||||||
elif entry.is_dir():
|
elif entry.is_dir():
|
||||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
scan = scan_directory(entry.path, suffixes, directoriesFirst=False, nameFilter=nameFilter)
|
||||||
|
|
||||||
if len(scan) != 0:
|
if len(scan) != 0:
|
||||||
tree.append((entry.name, scan))
|
tree.append((entry.name, scan))
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def listModels(model_type):
|
def listModels(model_type, nameFilter = None):
|
||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
|
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
@ -368,7 +376,7 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
default_tree = models["options"].get(model_type, [])
|
default_tree = models["options"].get(model_type, [])
|
||||||
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree)
|
models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter)
|
||||||
except MaliciousModelException as e:
|
except MaliciousModelException as e:
|
||||||
models["scan-error"] = str(e)
|
models["scan-error"] = str(e)
|
||||||
|
|
||||||
@ -380,7 +388,7 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
listModels(model_type="hypernetwork")
|
listModels(model_type="hypernetwork")
|
||||||
listModels(model_type="gfpgan")
|
listModels(model_type="gfpgan")
|
||||||
listModels(model_type="lora")
|
listModels(model_type="lora")
|
||||||
listModels(model_type="embeddings")
|
listModels(model_type="embeddings", nameFilter=get_embedding_token)
|
||||||
listModels(model_type="controlnet")
|
listModels(model_type="controlnet")
|
||||||
|
|
||||||
if scan_for_malicious and models_scanned > 0:
|
if scan_for_malicious and models_scanned > 0:
|
||||||
|
@ -10,6 +10,7 @@ from easydiffusion import app
|
|||||||
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
|
from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData
|
||||||
from numpy import base_repr
|
from numpy import base_repr
|
||||||
from sdkit.utils import save_dicts, save_images
|
from sdkit.utils import save_dicts, save_images
|
||||||
|
from sdkit.models.model_loader.embeddings import get_embedding_token
|
||||||
|
|
||||||
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
||||||
img_number_regex = re.compile("([0-9]{5,})")
|
img_number_regex = re.compile("([0-9]{5,})")
|
||||||
@ -228,7 +229,26 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output
|
|||||||
metadata[key] = req_metadata[key]
|
metadata[key] = req_metadata[key]
|
||||||
elif key in task_data_metadata:
|
elif key in task_data_metadata:
|
||||||
metadata[key] = task_data_metadata[key]
|
metadata[key] = task_data_metadata[key]
|
||||||
|
|
||||||
|
if key == "use_embeddings_model" and using_diffusers:
|
||||||
|
embeddings_extensions = {".pt", ".bin", ".safetensors"}
|
||||||
|
|
||||||
|
def scan_directory(directory_path: str):
|
||||||
|
used_embeddings = []
|
||||||
|
for entry in os.scandir(directory_path):
|
||||||
|
if entry.is_file():
|
||||||
|
# Check if the filename has the right extension
|
||||||
|
if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)):
|
||||||
|
continue
|
||||||
|
embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])")
|
||||||
|
if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt):
|
||||||
|
used_embeddings.append(entry.path)
|
||||||
|
elif entry.is_dir():
|
||||||
|
used_embeddings.extend(scan_directory(entry.path))
|
||||||
|
return used_embeddings
|
||||||
|
|
||||||
|
used_embeddings = scan_directory(os.path.join(app.MODELS_DIR, "embeddings"))
|
||||||
|
metadata["use_embeddings_model"] = used_embeddings if len(used_embeddings) > 0 else None
|
||||||
|
|
||||||
# Clean up the metadata
|
# Clean up the metadata
|
||||||
if req.init_image is None and "prompt_strength" in metadata:
|
if req.init_image is None and "prompt_strength" in metadata:
|
||||||
|
Loading…
Reference in New Issue
Block a user