Support both WebUI and ED folder names for models

This commit is contained in:
cmdr2 2024-10-07 12:48:27 +05:30
parent 9abc76482c
commit b6ba782c35
4 changed files with 75 additions and 50 deletions

View File

@ -7,7 +7,7 @@ import psutil
import shutil
from easydiffusion.app import ROOT_DIR, getConfig
from easydiffusion.model_manager import get_model_dir
from easydiffusion.model_manager import get_model_dirs
from . import impl
from .impl import (
@ -74,6 +74,8 @@ def install_backend():
# create the conda env
run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR)
print("Installing packages..")
# install python 3.10 and git in the conda env
run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR)
@ -336,7 +338,7 @@ def kill(proc_pid):
def get_model_path_args():
args = []
for model_type, flag in MODELS_TO_OVERRIDE.items():
model_dir = get_model_dir(model_type)
model_dir = get_model_dirs(model_type)[0]
args.append(f'{flag} "{model_dir}"')
return " ".join(args)

View File

@ -51,6 +51,16 @@ DEFAULT_MODELS = {
],
}
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility
"stable-diffusion": "Stable-diffusion",
"vae": "VAE",
"hypernetwork": "hypernetworks",
"codeformer": "Codeformer",
"gfpgan": "GFPGAN",
"realesrgan": "RealESRGAN",
"lora": "Lora",
"controlnet": "ControlNet",
}
known_models = {}
@ -122,33 +132,33 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig()
model_dir = get_model_dir(model_type)
if not model_name: # When None try user configured model.
# config = getConfig()
if "model" in config and model_type in config["model"]:
model_name = config["model"][model_type]
if model_name:
# Check models directory
model_path = os.path.join(model_dir, model_name)
if os.path.exists(model_path):
return model_path
for model_extension in model_extensions:
if os.path.exists(model_path + model_extension):
return model_path + model_extension
if os.path.exists(model_name + model_extension):
return os.path.abspath(model_name + model_extension)
for model_dir in get_model_dirs(model_type):
if model_name:
# Check models directory
model_path = os.path.join(model_dir, model_name)
if os.path.exists(model_path):
return model_path
for model_extension in model_extensions:
if os.path.exists(model_path + model_extension):
return model_path + model_extension
if os.path.exists(model_name + model_extension):
return os.path.abspath(model_name + model_extension)
# Can't find requested model, check the default paths.
if model_type == "stable-diffusion" and not fail_if_not_found:
for default_model in default_models:
default_model_path = os.path.join(model_dir, default_model["file_name"])
if os.path.exists(default_model_path):
if model_name is not None:
log.warn(
f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
)
return default_model_path
# Can't find requested model, check the default paths.
if model_type == "stable-diffusion" and not fail_if_not_found:
for default_model in default_models:
default_model_path = os.path.join(model_dir, default_model["file_name"])
if os.path.exists(default_model_path):
if model_name is not None:
log.warn(
f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
)
return default_model_path
if model_name and fail_if_not_found:
raise FileNotFoundError(
@ -239,7 +249,7 @@ def download_default_models_if_necessary():
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
model_dir = get_model_dir(model_type)
model_dir = get_model_dirs(model_type)[0]
model_path = os.path.join(model_dir, file_name)
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
@ -260,23 +270,23 @@ def migrate_legacy_model_location():
file_name = model["file_name"]
legacy_path = os.path.join(app.SD_DIR, file_name)
if os.path.exists(legacy_path):
model_dir = get_model_dir(model_type)
model_dir = get_model_dirs(model_type)[0]
shutil.move(legacy_path, os.path.join(model_dir, file_name))
def any_model_exists(model_type: str) -> bool:
extensions = MODEL_EXTENSIONS.get(model_type, [])
model_dir = get_model_dir(model_type)
for ext in extensions:
if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
return True
for model_dir in get_model_dirs(model_type):
for ext in extensions:
if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
return True
return False
def make_model_folders():
for model_type in KNOWN_MODEL_TYPES:
model_dir_path = get_model_dir(model_type)
model_dir_path = get_model_dirs(model_type)[0]
try:
os.makedirs(model_dir_path, exist_ok=True)
@ -377,6 +387,9 @@ def getModels(scan_for_malicious: bool = True):
tree = list(default_entries)
if not os.path.exists(directory):
return tree
for entry in sorted(
os.scandir(directory),
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
@ -421,17 +434,23 @@ def getModels(scan_for_malicious: bool = True):
nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = get_model_dir(model_type)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
models_dirs = get_model_dirs(model_type)
if not os.path.exists(models_dirs[0]):
os.makedirs(models_dirs[0])
try:
default_tree = models["options"].get(model_type, [])
models["options"][model_type] = scan_directory(
models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
)
except MaliciousModelException as e:
models["scan-error"] = str(e)
models["options"][model_type] = []
default_tree = models["options"].get(model_type, [])
for model_dir in models_dirs:
try:
scanned_models = scan_directory(
model_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
)
for m in scanned_models:
if m not in models["options"][model_type]:
models["options"][model_type].append(m)
except MaliciousModelException as e:
models["scan-error"] = str(e)
if scan_for_malicious:
log.info(f"[green]Scanning all model folders for models...[/]")
@ -450,17 +469,21 @@ def getModels(scan_for_malicious: bool = True):
return models
def get_model_dir(model_type: str, base_dir=None):
"Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)"
def get_model_dirs(model_type: str, base_dir=None):
"Returns the possible model directory paths for the given model type. Mainly used for WebUI compatibility"
if base_dir is None:
base_dir = app.MODELS_DIR
for dir in os.listdir(base_dir):
if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)):
return os.path.join(base_dir, dir)
dirs = [os.path.join(base_dir, model_type)]
return os.path.join(base_dir, model_type)
if model_type in ALTERNATE_FOLDER_NAMES:
alt_dir = ALTERNATE_FOLDER_NAMES[model_type]
alt_dir = os.path.join(base_dir, alt_dir)
if os.path.exists(alt_dir) and os.path.isdir(alt_dir):
dirs.append(alt_dir)
return dirs
# patch sdkit
@ -468,7 +491,7 @@ def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_mod
"Patched version that works with case-insensitive model sub-dirs"
download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir
download_base_dir = get_model_dirs(model_type, download_base_dir)[0] if subdir_for_model_type else download_base_dir
return os.path.abspath(download_base_dir)

View File

@ -364,7 +364,7 @@ def model_merge_internal(req: dict):
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
sd_model_dir = model_manager.get_model_dir("stable-diffusion")
sd_model_dir = model_manager.get_model_dir("stable-diffusion")[0]
merge_models(
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),

View File

@ -13,7 +13,7 @@ nsfw_check_model = None
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
global nsfw_check_model
from easydiffusion.model_manager import get_model_dir
from easydiffusion.model_manager import get_model_dirs
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
import onnxruntime as ort
@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True):
import numpy as np
if nsfw_check_model is None:
model_dir = get_model_dir("nsfw-checker")
model_dir = get_model_dirs("nsfw-checker")[0]
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
os.makedirs(model_dir, exist_ok=True)