diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index f78d1164..abeece38 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -6,6 +6,7 @@ from threading import local import psutil from easydiffusion.app import ROOT_DIR, getConfig +from easydiffusion.model_manager import get_model_dir from . import impl from .impl import ( @@ -32,6 +33,18 @@ BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui")) SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") WEBUI_DIR = os.path.join(BACKEND_DIR, "webui") +MODELS_TO_OVERRIDE = { + "stable-diffusion": "--ckpt-dir", + "vae": "--vae-dir", + "hypernetwork": "--hypernetwork-dir", + "gfpgan": "--gfpgan-models-path", + "realesrgan": "--realesrgan-models-path", + "lora": "--lora-dir", + "codeformer": "--codeformer-models-path", + "embeddings": "--embeddings-dir", + "controlnet": "--controlnet-dir", +} + backend_process = None @@ -104,7 +117,8 @@ def get_env(): config = getConfig() models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models")) - embeddings_dir = os.path.join(models_dir, "embeddings") + + model_path_args = get_model_path_args() env_entries = { "PATH": [ @@ -125,7 +139,7 @@ def get_env(): "PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"], "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], - "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'], + "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args}'], "SKIP_VENV": ["1"], "SD_WEBUI_RESTARTING": ["1"], "PYTHON": [f"{dir}/python/python"], @@ -153,3 +167,12 @@ def kill(proc_pid): for proc in process.children(recursive=True): proc.kill() process.kill() + + +def get_model_path_args(): + args = [] + for model_type, flag in MODELS_TO_OVERRIDE.items(): + model_dir = get_model_dir(model_type) + args.append(f'{flag} "{model_dir}"') + + return " ".join(args) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index d821db41..2a20fce9 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -122,7 +122,7 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None, default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() - model_dir = os.path.join(app.MODELS_DIR, model_type) + model_dir = get_model_dir(model_type) if not model_name: # When None try user configured model. # config = getConfig() if "model" in config and model_type in config["model"]: @@ -239,7 +239,8 @@ def download_default_models_if_necessary(): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): - model_path = os.path.join(app.MODELS_DIR, model_type, file_name) + model_dir = get_model_dir(model_type) + model_path = os.path.join(model_dir, file_name) expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] other_models_exist = any_model_exists(model_type) and skip_if_others_exist @@ -259,13 +260,15 @@ def migrate_legacy_model_location(): file_name = model["file_name"] legacy_path = os.path.join(app.SD_DIR, file_name) if os.path.exists(legacy_path): - shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name)) + model_dir = get_model_dir(model_type) + shutil.move(legacy_path, os.path.join(model_dir, file_name)) def any_model_exists(model_type: str) -> bool: extensions = MODEL_EXTENSIONS.get(model_type, []) + model_dir = get_model_dir(model_type) for ext in extensions: - if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)): + if any(glob(f"{model_dir}/**/*{ext}", recursive=True)): return True return False @@ -273,7 +276,7 @@ def any_model_exists(model_type: str) -> bool: def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: - model_dir_path = os.path.join(app.MODELS_DIR, model_type) + model_dir_path = get_model_dir(model_type) try: os.makedirs(model_dir_path, exist_ok=True) @@ -418,7 +421,7 @@ def getModels(scan_for_malicious: bool = True): nonlocal models_scanned model_extensions = MODEL_EXTENSIONS.get(model_type, []) - models_dir = os.path.join(app.MODELS_DIR, model_type) + models_dir = get_model_dir(model_type) if not os.path.exists(models_dir): os.makedirs(models_dir) @@ -445,3 +448,30 @@ def getModels(scan_for_malicious: bool = True): log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") return models + + +def get_model_dir(model_type: str, base_dir=None): + "Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)" + + if base_dir is None: + base_dir = app.MODELS_DIR + + for dir in os.listdir(base_dir): + if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)): + return os.path.join(base_dir, dir) + + return os.path.join(base_dir, model_type) + + +# patch sdkit +def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type): + "Patched version that works with case-insensitive model sub-dirs" + + download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir + download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir + return os.path.abspath(download_base_dir) + + +from sdkit.models import model_downloader + +model_downloader.get_actual_base_dir = __patched__get_actual_base_dir diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index ca7dc98e..f1b85764 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -364,15 +364,13 @@ def model_merge_internal(req: dict): mergeReq: MergeRequest = MergeRequest.parse_obj(req) + sd_model_dir = model_manager.get_model_dir("stable-diffusion") + merge_models( model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), mergeReq.ratio, - os.path.join( - app.MODELS_DIR, - "stable-diffusion", - filename_regex.sub("_", mergeReq.out_path), - ), + os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)), mergeReq.use_fp16, ) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) diff --git a/ui/easydiffusion/utils/nsfw_checker.py b/ui/easydiffusion/utils/nsfw_checker.py index 3790cacc..9e371a37 100644 --- a/ui/easydiffusion/utils/nsfw_checker.py +++ b/ui/easydiffusion/utils/nsfw_checker.py @@ -13,7 +13,7 @@ nsfw_check_model = None def filter_nsfw(images, blur_radius: float = 75, print_log=True): global nsfw_check_model - from easydiffusion.app import MODELS_DIR + from easydiffusion.model_manager import get_model_dir from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick import onnxruntime as ort @@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True): import numpy as np if nsfw_check_model is None: - model_dir = os.path.join(MODELS_DIR, "nsfw-checker") + model_dir = get_model_dir("nsfw-checker") model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") os.makedirs(model_dir, exist_ok=True) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index a597b281..c8624f40 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -102,7 +102,7 @@ var PARAMETERS = [ type: ParameterType.custom, icon: "fa-folder-tree", label: "Models Folder", - note: "Path to the 'models' folder. Please save and refresh the page after changing this.", + note: "Path to the 'models' folder. Please save and restart Easy Diffusion after changing this.", saveInAppConfig: true, render: (parameter) => { return ``