From b6ba782c35124bfa30de749211c7673da93faaa1 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 7 Oct 2024 12:48:27 +0530 Subject: [PATCH] Support both WebUI and ED folder names for models --- ui/easydiffusion/backends/webui/__init__.py | 6 +- ui/easydiffusion/model_manager.py | 113 ++++++++++++-------- ui/easydiffusion/server.py | 2 +- ui/easydiffusion/utils/nsfw_checker.py | 4 +- 4 files changed, 75 insertions(+), 50 deletions(-) diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index 95cd7883..6b21ae1c 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -7,7 +7,7 @@ import psutil import shutil from easydiffusion.app import ROOT_DIR, getConfig -from easydiffusion.model_manager import get_model_dir +from easydiffusion.model_manager import get_model_dirs from . import impl from .impl import ( @@ -74,6 +74,8 @@ def install_backend(): # create the conda env run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR) + print("Installing packages..") + # install python 3.10 and git in the conda env run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR) @@ -336,7 +338,7 @@ def kill(proc_pid): def get_model_path_args(): args = [] for model_type, flag in MODELS_TO_OVERRIDE.items(): - model_dir = get_model_dir(model_type) + model_dir = get_model_dirs(model_type)[0] args.append(f'{flag} "{model_dir}"') return " ".join(args) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 2a20fce9..60c286c9 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -51,6 +51,16 @@ DEFAULT_MODELS = { ], } MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"] +ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility + "stable-diffusion": "Stable-diffusion", + "vae": "VAE", + "hypernetwork": "hypernetworks", + "codeformer": "Codeformer", + "gfpgan": "GFPGAN", + "realesrgan": "RealESRGAN", + "lora": "Lora", + "controlnet": "ControlNet", +} known_models = {} @@ -122,33 +132,33 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() - model_dir = get_model_dir(model_type) if not model_name: # When None try user configured model. # config = getConfig() if "model" in config and model_type in config["model"]: model_name = config["model"][model_type] - if model_name: - # Check models directory - model_path = os.path.join(model_dir, model_name) - if os.path.exists(model_path): - return model_path - for model_extension in model_extensions: - if os.path.exists(model_path + model_extension): - return model_path + model_extension - if os.path.exists(model_name + model_extension): - return os.path.abspath(model_name + model_extension) + for model_dir in get_model_dirs(model_type): + if model_name: + # Check models directory + model_path = os.path.join(model_dir, model_name) + if os.path.exists(model_path): + return model_path + for model_extension in model_extensions: + if os.path.exists(model_path + model_extension): + return model_path + model_extension + if os.path.exists(model_name + model_extension): + return os.path.abspath(model_name + model_extension) - # Can't find requested model, check the default paths. - if model_type == "stable-diffusion" and not fail_if_not_found: - for default_model in default_models: - default_model_path = os.path.join(model_dir, default_model["file_name"]) - if os.path.exists(default_model_path): - if model_name is not None: - log.warn( - f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}" - ) - return default_model_path + # Can't find requested model, check the default paths. + if model_type == "stable-diffusion" and not fail_if_not_found: + for default_model in default_models: + default_model_path = os.path.join(model_dir, default_model["file_name"]) + if os.path.exists(default_model_path): + if model_name is not None: + log.warn( + f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}" + ) + return default_model_path if model_name and fail_if_not_found: raise FileNotFoundError( @@ -239,7 +249,7 @@ def download_default_models_if_necessary(): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): - model_dir = get_model_dir(model_type) + model_dir = get_model_dirs(model_type)[0] model_path = os.path.join(model_dir, file_name) expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] @@ -260,23 +270,23 @@ def migrate_legacy_model_location(): file_name = model["file_name"] legacy_path = os.path.join(app.SD_DIR, file_name) if os.path.exists(legacy_path): - model_dir = get_model_dir(model_type) + model_dir = get_model_dirs(model_type)[0] shutil.move(legacy_path, os.path.join(model_dir, file_name)) def any_model_exists(model_type: str) -> bool: extensions = MODEL_EXTENSIONS.get(model_type, []) - model_dir = get_model_dir(model_type) - for ext in extensions: - if any(glob(f"{model_dir}/**/*{ext}", recursive=True)): - return True + for model_dir in get_model_dirs(model_type): + for ext in extensions: + if any(glob(f"{model_dir}/**/*{ext}", recursive=True)): + return True return False def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: - model_dir_path = get_model_dir(model_type) + model_dir_path = get_model_dirs(model_type)[0] try: os.makedirs(model_dir_path, exist_ok=True) @@ -377,6 +387,9 @@ def getModels(scan_for_malicious: bool = True): tree = list(default_entries) + if not os.path.exists(directory): + return tree + for entry in sorted( os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), @@ -421,17 +434,23 @@ def getModels(scan_for_malicious: bool = True): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) - models_dir = get_model_dir(model_type) - if not os.path.exists(models_dir): - os.makedirs(models_dir) + models_dirs = get_model_dirs(model_type) + if not os.path.exists(models_dirs[0]): + os.makedirs(models_dirs[0]) - try: - default_tree = models["options"].get(model_type, []) - models["options"][model_type] = scan_directory( - models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter - ) - except MaliciousModelException as e: - models["scan-error"] = str(e) + models["options"][model_type] = [] + default_tree = models["options"].get(model_type, []) + + for model_dir in models_dirs: + try: + scanned_models = scan_directory( + model_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter + ) + for m in scanned_models: + if m not in models["options"][model_type]: + models["options"][model_type].append(m) + except MaliciousModelException as e: + models["scan-error"] = str(e) if scan_for_malicious: log.info(f"[green]Scanning all model folders for models...[/]") @@ -450,17 +469,21 @@ def getModels(scan_for_malicious: bool = True): return models -def get_model_dir(model_type: str, base_dir=None): - "Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)" +def get_model_dirs(model_type: str, base_dir=None): + "Returns the possible model directory paths for the given model type. Mainly used for WebUI compatibility" if base_dir is None: base_dir = app.MODELS_DIR - for dir in os.listdir(base_dir): - if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)): - return os.path.join(base_dir, dir) + dirs = [os.path.join(base_dir, model_type)] - return os.path.join(base_dir, model_type) + if model_type in ALTERNATE_FOLDER_NAMES: + alt_dir = ALTERNATE_FOLDER_NAMES[model_type] + alt_dir = os.path.join(base_dir, alt_dir) + if os.path.exists(alt_dir) and os.path.isdir(alt_dir): + dirs.append(alt_dir) + + return dirs # patch sdkit @@ -468,7 +491,7 @@ def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_mod "Patched version that works with case-insensitive model sub-dirs" download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir - download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir + download_base_dir = get_model_dirs(model_type, download_base_dir)[0] if subdir_for_model_type else download_base_dir return os.path.abspath(download_base_dir) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index f1b85764..63d940aa 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -364,7 +364,7 @@ def model_merge_internal(req: dict): mergeReq: MergeRequest = MergeRequest.parse_obj(req) - sd_model_dir = model_manager.get_model_dir("stable-diffusion") + sd_model_dir = model_manager.get_model_dir("stable-diffusion")[0] merge_models( model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), diff --git a/ui/easydiffusion/utils/nsfw_checker.py b/ui/easydiffusion/utils/nsfw_checker.py index 9e371a37..51a684df 100644 --- a/ui/easydiffusion/utils/nsfw_checker.py +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -13,7 +13,7 @@ nsfw_check_model = None def filter_nsfw(images, blur_radius: float = 75, print_log=True): global nsfw_check_model - from easydiffusion.model_manager import get_model_dir + from easydiffusion.model_manager import get_model_dirs from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick import onnxruntime as ort @@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True): import numpy as np if nsfw_check_model is None: - model_dir = get_model_dir("nsfw-checker") + model_dir = get_model_dirs("nsfw-checker")[0] model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") os.makedirs(model_dir, exist_ok=True)