Case-insensitive model directories

This commit is contained in:
cmdr2 2024-10-01 13:55:35 +05:30
parent 9a12a8618c
commit 754a5f5e52
5 changed files with 67 additions and 16 deletions

View File

@ -6,6 +6,7 @@ from threading import local
import psutil import psutil
from easydiffusion.app import ROOT_DIR, getConfig from easydiffusion.app import ROOT_DIR, getConfig
from easydiffusion.model_manager import get_model_dir
from . import impl from . import impl
from .impl import ( from .impl import (
@ -32,6 +33,18 @@ BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui"))
SYSTEM_DIR = os.path.join(BACKEND_DIR, "system") SYSTEM_DIR = os.path.join(BACKEND_DIR, "system")
WEBUI_DIR = os.path.join(BACKEND_DIR, "webui") WEBUI_DIR = os.path.join(BACKEND_DIR, "webui")
MODELS_TO_OVERRIDE = {
"stable-diffusion": "--ckpt-dir",
"vae": "--vae-dir",
"hypernetwork": "--hypernetwork-dir",
"gfpgan": "--gfpgan-models-path",
"realesrgan": "--realesrgan-models-path",
"lora": "--lora-dir",
"codeformer": "--codeformer-models-path",
"embeddings": "--embeddings-dir",
"controlnet": "--controlnet-dir",
}
backend_process = None backend_process = None
@ -104,7 +117,8 @@ def get_env():
config = getConfig() config = getConfig()
models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models")) models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
embeddings_dir = os.path.join(models_dir, "embeddings")
model_path_args = get_model_path_args()
env_entries = { env_entries = {
"PATH": [ "PATH": [
@ -125,7 +139,7 @@ def get_env():
"PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"], "PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"],
"TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"], "TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"],
"HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"], "HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"],
"COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'], "COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args}'],
"SKIP_VENV": ["1"], "SKIP_VENV": ["1"],
"SD_WEBUI_RESTARTING": ["1"], "SD_WEBUI_RESTARTING": ["1"],
"PYTHON": [f"{dir}/python/python"], "PYTHON": [f"{dir}/python/python"],
@ -153,3 +167,12 @@ def kill(proc_pid):
for proc in process.children(recursive=True): for proc in process.children(recursive=True):
proc.kill() proc.kill()
process.kill() process.kill()
def get_model_path_args():
args = []
for model_type, flag in MODELS_TO_OVERRIDE.items():
model_dir = get_model_dir(model_type)
args.append(f'{flag} "{model_dir}"')
return " ".join(args)

View File

@ -122,7 +122,7 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
default_models = DEFAULT_MODELS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig() config = app.getConfig()
model_dir = os.path.join(app.MODELS_DIR, model_type) model_dir = get_model_dir(model_type)
if not model_name: # When None try user configured model. if not model_name: # When None try user configured model.
# config = getConfig() # config = getConfig()
if "model" in config and model_type in config["model"]: if "model" in config and model_type in config["model"]:
@ -239,7 +239,8 @@ def download_default_models_if_necessary():
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
model_path = os.path.join(app.MODELS_DIR, model_type, file_name) model_dir = get_model_dir(model_type)
model_path = os.path.join(model_dir, file_name)
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
other_models_exist = any_model_exists(model_type) and skip_if_others_exist other_models_exist = any_model_exists(model_type) and skip_if_others_exist
@ -259,13 +260,15 @@ def migrate_legacy_model_location():
file_name = model["file_name"] file_name = model["file_name"]
legacy_path = os.path.join(app.SD_DIR, file_name) legacy_path = os.path.join(app.SD_DIR, file_name)
if os.path.exists(legacy_path): if os.path.exists(legacy_path):
shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name)) model_dir = get_model_dir(model_type)
shutil.move(legacy_path, os.path.join(model_dir, file_name))
def any_model_exists(model_type: str) -> bool: def any_model_exists(model_type: str) -> bool:
extensions = MODEL_EXTENSIONS.get(model_type, []) extensions = MODEL_EXTENSIONS.get(model_type, [])
model_dir = get_model_dir(model_type)
for ext in extensions: for ext in extensions:
if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)): if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
return True return True
return False return False
@ -273,7 +276,7 @@ def any_model_exists(model_type: str) -> bool:
def make_model_folders(): def make_model_folders():
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type) model_dir_path = get_model_dir(model_type)
try: try:
os.makedirs(model_dir_path, exist_ok=True) os.makedirs(model_dir_path, exist_ok=True)
@ -418,7 +421,7 @@ def getModels(scan_for_malicious: bool = True):
nonlocal models_scanned nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, []) model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = os.path.join(app.MODELS_DIR, model_type) models_dir = get_model_dir(model_type)
if not os.path.exists(models_dir): if not os.path.exists(models_dir):
os.makedirs(models_dir) os.makedirs(models_dir)
@ -445,3 +448,30 @@ def getModels(scan_for_malicious: bool = True):
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
return models return models
def get_model_dir(model_type: str, base_dir=None):
"Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)"
if base_dir is None:
base_dir = app.MODELS_DIR
for dir in os.listdir(base_dir):
if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)):
return os.path.join(base_dir, dir)
return os.path.join(base_dir, model_type)
# patch sdkit
def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type):
"Patched version that works with case-insensitive model sub-dirs"
download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir
return os.path.abspath(download_base_dir)
from sdkit.models import model_downloader
model_downloader.get_actual_base_dir = __patched__get_actual_base_dir

View File

@ -364,15 +364,13 @@ def model_merge_internal(req: dict):
mergeReq: MergeRequest = MergeRequest.parse_obj(req) mergeReq: MergeRequest = MergeRequest.parse_obj(req)
sd_model_dir = model_manager.get_model_dir("stable-diffusion")
merge_models( merge_models(
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
mergeReq.ratio, mergeReq.ratio,
os.path.join( os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)),
app.MODELS_DIR,
"stable-diffusion",
filename_regex.sub("_", mergeReq.out_path),
),
mergeReq.use_fp16, mergeReq.use_fp16,
) )
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)

View File

@ -13,7 +13,7 @@ nsfw_check_model = None
def filter_nsfw(images, blur_radius: float = 75, print_log=True): def filter_nsfw(images, blur_radius: float = 75, print_log=True):
global nsfw_check_model global nsfw_check_model
from easydiffusion.app import MODELS_DIR from easydiffusion.model_manager import get_model_dir
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
import onnxruntime as ort import onnxruntime as ort
@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True):
import numpy as np import numpy as np
if nsfw_check_model is None: if nsfw_check_model is None:
model_dir = os.path.join(MODELS_DIR, "nsfw-checker") model_dir = get_model_dir("nsfw-checker")
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx") model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)

View File

@ -102,7 +102,7 @@ var PARAMETERS = [
type: ParameterType.custom, type: ParameterType.custom,
icon: "fa-folder-tree", icon: "fa-folder-tree",
label: "Models Folder", label: "Models Folder",
note: "Path to the 'models' folder. Please save and refresh the page after changing this.", note: "Path to the 'models' folder. Please save and restart Easy Diffusion after changing this.",
saveInAppConfig: true, saveInAppConfig: true,
render: (parameter) => { render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="30">` return `<input id="${parameter.id}" name="${parameter.id}" size="30">`