mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-07 14:59:32 +01:00
Case-insensitive model directories
This commit is contained in:
parent
9a12a8618c
commit
754a5f5e52
@ -6,6 +6,7 @@ from threading import local
|
||||
import psutil
|
||||
|
||||
from easydiffusion.app import ROOT_DIR, getConfig
|
||||
from easydiffusion.model_manager import get_model_dir
|
||||
|
||||
from . import impl
|
||||
from .impl import (
|
||||
@ -32,6 +33,18 @@ BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui"))
|
||||
SYSTEM_DIR = os.path.join(BACKEND_DIR, "system")
|
||||
WEBUI_DIR = os.path.join(BACKEND_DIR, "webui")
|
||||
|
||||
MODELS_TO_OVERRIDE = {
|
||||
"stable-diffusion": "--ckpt-dir",
|
||||
"vae": "--vae-dir",
|
||||
"hypernetwork": "--hypernetwork-dir",
|
||||
"gfpgan": "--gfpgan-models-path",
|
||||
"realesrgan": "--realesrgan-models-path",
|
||||
"lora": "--lora-dir",
|
||||
"codeformer": "--codeformer-models-path",
|
||||
"embeddings": "--embeddings-dir",
|
||||
"controlnet": "--controlnet-dir",
|
||||
}
|
||||
|
||||
backend_process = None
|
||||
|
||||
|
||||
@ -104,7 +117,8 @@ def get_env():
|
||||
|
||||
config = getConfig()
|
||||
models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
|
||||
embeddings_dir = os.path.join(models_dir, "embeddings")
|
||||
|
||||
model_path_args = get_model_path_args()
|
||||
|
||||
env_entries = {
|
||||
"PATH": [
|
||||
@ -125,7 +139,7 @@ def get_env():
|
||||
"PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"],
|
||||
"TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"],
|
||||
"HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"],
|
||||
"COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'],
|
||||
"COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args}'],
|
||||
"SKIP_VENV": ["1"],
|
||||
"SD_WEBUI_RESTARTING": ["1"],
|
||||
"PYTHON": [f"{dir}/python/python"],
|
||||
@ -153,3 +167,12 @@ def kill(proc_pid):
|
||||
for proc in process.children(recursive=True):
|
||||
proc.kill()
|
||||
process.kill()
|
||||
|
||||
|
||||
def get_model_path_args():
|
||||
args = []
|
||||
for model_type, flag in MODELS_TO_OVERRIDE.items():
|
||||
model_dir = get_model_dir(model_type)
|
||||
args.append(f'{flag} "{model_dir}"')
|
||||
|
||||
return " ".join(args)
|
||||
|
@ -122,7 +122,7 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||
config = app.getConfig()
|
||||
|
||||
model_dir = os.path.join(app.MODELS_DIR, model_type)
|
||||
model_dir = get_model_dir(model_type)
|
||||
if not model_name: # When None try user configured model.
|
||||
# config = getConfig()
|
||||
if "model" in config and model_type in config["model"]:
|
||||
@ -239,7 +239,8 @@ def download_default_models_if_necessary():
|
||||
|
||||
|
||||
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
|
||||
model_path = os.path.join(app.MODELS_DIR, model_type, file_name)
|
||||
model_dir = get_model_dir(model_type)
|
||||
model_path = os.path.join(model_dir, file_name)
|
||||
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
||||
|
||||
other_models_exist = any_model_exists(model_type) and skip_if_others_exist
|
||||
@ -259,13 +260,15 @@ def migrate_legacy_model_location():
|
||||
file_name = model["file_name"]
|
||||
legacy_path = os.path.join(app.SD_DIR, file_name)
|
||||
if os.path.exists(legacy_path):
|
||||
shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name))
|
||||
model_dir = get_model_dir(model_type)
|
||||
shutil.move(legacy_path, os.path.join(model_dir, file_name))
|
||||
|
||||
|
||||
def any_model_exists(model_type: str) -> bool:
|
||||
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
model_dir = get_model_dir(model_type)
|
||||
for ext in extensions:
|
||||
if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)):
|
||||
if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -273,7 +276,7 @@ def any_model_exists(model_type: str) -> bool:
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
model_dir_path = get_model_dir(model_type)
|
||||
|
||||
try:
|
||||
os.makedirs(model_dir_path, exist_ok=True)
|
||||
@ -418,7 +421,7 @@ def getModels(scan_for_malicious: bool = True):
|
||||
nonlocal models_scanned
|
||||
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
models_dir = os.path.join(app.MODELS_DIR, model_type)
|
||||
models_dir = get_model_dir(model_type)
|
||||
if not os.path.exists(models_dir):
|
||||
os.makedirs(models_dir)
|
||||
|
||||
@ -445,3 +448,30 @@ def getModels(scan_for_malicious: bool = True):
|
||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||
|
||||
return models
|
||||
|
||||
|
||||
def get_model_dir(model_type: str, base_dir=None):
|
||||
"Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)"
|
||||
|
||||
if base_dir is None:
|
||||
base_dir = app.MODELS_DIR
|
||||
|
||||
for dir in os.listdir(base_dir):
|
||||
if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)):
|
||||
return os.path.join(base_dir, dir)
|
||||
|
||||
return os.path.join(base_dir, model_type)
|
||||
|
||||
|
||||
# patch sdkit
|
||||
def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type):
|
||||
"Patched version that works with case-insensitive model sub-dirs"
|
||||
|
||||
download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
|
||||
download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir
|
||||
return os.path.abspath(download_base_dir)
|
||||
|
||||
|
||||
from sdkit.models import model_downloader
|
||||
|
||||
model_downloader.get_actual_base_dir = __patched__get_actual_base_dir
|
||||
|
@ -364,15 +364,13 @@ def model_merge_internal(req: dict):
|
||||
|
||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||
|
||||
sd_model_dir = model_manager.get_model_dir("stable-diffusion")
|
||||
|
||||
merge_models(
|
||||
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
||||
mergeReq.ratio,
|
||||
os.path.join(
|
||||
app.MODELS_DIR,
|
||||
"stable-diffusion",
|
||||
filename_regex.sub("_", mergeReq.out_path),
|
||||
),
|
||||
os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)),
|
||||
mergeReq.use_fp16,
|
||||
)
|
||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||
|
@ -13,7 +13,7 @@ nsfw_check_model = None
|
||||
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
||||
global nsfw_check_model
|
||||
|
||||
from easydiffusion.app import MODELS_DIR
|
||||
from easydiffusion.model_manager import get_model_dir
|
||||
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
|
||||
|
||||
import onnxruntime as ort
|
||||
@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
||||
import numpy as np
|
||||
|
||||
if nsfw_check_model is None:
|
||||
model_dir = os.path.join(MODELS_DIR, "nsfw-checker")
|
||||
model_dir = get_model_dir("nsfw-checker")
|
||||
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
|
||||
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
@ -102,7 +102,7 @@ var PARAMETERS = [
|
||||
type: ParameterType.custom,
|
||||
icon: "fa-folder-tree",
|
||||
label: "Models Folder",
|
||||
note: "Path to the 'models' folder. Please save and refresh the page after changing this.",
|
||||
note: "Path to the 'models' folder. Please save and restart Easy Diffusion after changing this.",
|
||||
saveInAppConfig: true,
|
||||
render: (parameter) => {
|
||||
return `<input id="${parameter.id}" name="${parameter.id}" size="30">`
|
||||
|
Loading…
Reference in New Issue
Block a user