mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-23 06:39:50 +01:00
Case-insensitive model directories
This commit is contained in:
parent
9a12a8618c
commit
754a5f5e52
@ -6,6 +6,7 @@ from threading import local
|
|||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
from easydiffusion.app import ROOT_DIR, getConfig
|
from easydiffusion.app import ROOT_DIR, getConfig
|
||||||
|
from easydiffusion.model_manager import get_model_dir
|
||||||
|
|
||||||
from . import impl
|
from . import impl
|
||||||
from .impl import (
|
from .impl import (
|
||||||
@ -32,6 +33,18 @@ BACKEND_DIR = os.path.abspath(os.path.join(ROOT_DIR, "webui"))
|
|||||||
SYSTEM_DIR = os.path.join(BACKEND_DIR, "system")
|
SYSTEM_DIR = os.path.join(BACKEND_DIR, "system")
|
||||||
WEBUI_DIR = os.path.join(BACKEND_DIR, "webui")
|
WEBUI_DIR = os.path.join(BACKEND_DIR, "webui")
|
||||||
|
|
||||||
|
MODELS_TO_OVERRIDE = {
|
||||||
|
"stable-diffusion": "--ckpt-dir",
|
||||||
|
"vae": "--vae-dir",
|
||||||
|
"hypernetwork": "--hypernetwork-dir",
|
||||||
|
"gfpgan": "--gfpgan-models-path",
|
||||||
|
"realesrgan": "--realesrgan-models-path",
|
||||||
|
"lora": "--lora-dir",
|
||||||
|
"codeformer": "--codeformer-models-path",
|
||||||
|
"embeddings": "--embeddings-dir",
|
||||||
|
"controlnet": "--controlnet-dir",
|
||||||
|
}
|
||||||
|
|
||||||
backend_process = None
|
backend_process = None
|
||||||
|
|
||||||
|
|
||||||
@ -104,7 +117,8 @@ def get_env():
|
|||||||
|
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
|
models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
|
||||||
embeddings_dir = os.path.join(models_dir, "embeddings")
|
|
||||||
|
model_path_args = get_model_path_args()
|
||||||
|
|
||||||
env_entries = {
|
env_entries = {
|
||||||
"PATH": [
|
"PATH": [
|
||||||
@ -125,7 +139,7 @@ def get_env():
|
|||||||
"PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"],
|
"PIP_INSTALLER_LOCATION": [f"{dir}/python/get-pip.py"],
|
||||||
"TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"],
|
"TRANSFORMERS_CACHE": [f"{dir}/transformers-cache"],
|
||||||
"HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"],
|
"HF_HUB_DISABLE_SYMLINKS_WARNING": ["true"],
|
||||||
"COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" --embeddings-dir "{embeddings_dir}"'],
|
"COMMANDLINE_ARGS": [f'--api --models-dir "{models_dir}" {model_path_args}'],
|
||||||
"SKIP_VENV": ["1"],
|
"SKIP_VENV": ["1"],
|
||||||
"SD_WEBUI_RESTARTING": ["1"],
|
"SD_WEBUI_RESTARTING": ["1"],
|
||||||
"PYTHON": [f"{dir}/python/python"],
|
"PYTHON": [f"{dir}/python/python"],
|
||||||
@ -153,3 +167,12 @@ def kill(proc_pid):
|
|||||||
for proc in process.children(recursive=True):
|
for proc in process.children(recursive=True):
|
||||||
proc.kill()
|
proc.kill()
|
||||||
process.kill()
|
process.kill()
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path_args():
|
||||||
|
args = []
|
||||||
|
for model_type, flag in MODELS_TO_OVERRIDE.items():
|
||||||
|
model_dir = get_model_dir(model_type)
|
||||||
|
args.append(f'{flag} "{model_dir}"')
|
||||||
|
|
||||||
|
return " ".join(args)
|
||||||
|
@ -122,7 +122,7 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
|||||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
|
||||||
model_dir = os.path.join(app.MODELS_DIR, model_type)
|
model_dir = get_model_dir(model_type)
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
# config = getConfig()
|
# config = getConfig()
|
||||||
if "model" in config and model_type in config["model"]:
|
if "model" in config and model_type in config["model"]:
|
||||||
@ -239,7 +239,8 @@ def download_default_models_if_necessary():
|
|||||||
|
|
||||||
|
|
||||||
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
|
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
|
||||||
model_path = os.path.join(app.MODELS_DIR, model_type, file_name)
|
model_dir = get_model_dir(model_type)
|
||||||
|
model_path = os.path.join(model_dir, file_name)
|
||||||
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
||||||
|
|
||||||
other_models_exist = any_model_exists(model_type) and skip_if_others_exist
|
other_models_exist = any_model_exists(model_type) and skip_if_others_exist
|
||||||
@ -259,13 +260,15 @@ def migrate_legacy_model_location():
|
|||||||
file_name = model["file_name"]
|
file_name = model["file_name"]
|
||||||
legacy_path = os.path.join(app.SD_DIR, file_name)
|
legacy_path = os.path.join(app.SD_DIR, file_name)
|
||||||
if os.path.exists(legacy_path):
|
if os.path.exists(legacy_path):
|
||||||
shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name))
|
model_dir = get_model_dir(model_type)
|
||||||
|
shutil.move(legacy_path, os.path.join(model_dir, file_name))
|
||||||
|
|
||||||
|
|
||||||
def any_model_exists(model_type: str) -> bool:
|
def any_model_exists(model_type: str) -> bool:
|
||||||
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
|
model_dir = get_model_dir(model_type)
|
||||||
for ext in extensions:
|
for ext in extensions:
|
||||||
if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)):
|
if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@ -273,7 +276,7 @@ def any_model_exists(model_type: str) -> bool:
|
|||||||
|
|
||||||
def make_model_folders():
|
def make_model_folders():
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
model_dir_path = get_model_dir(model_type)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(model_dir_path, exist_ok=True)
|
os.makedirs(model_dir_path, exist_ok=True)
|
||||||
@ -418,7 +421,7 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
|
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
models_dir = os.path.join(app.MODELS_DIR, model_type)
|
models_dir = get_model_dir(model_type)
|
||||||
if not os.path.exists(models_dir):
|
if not os.path.exists(models_dir):
|
||||||
os.makedirs(models_dir)
|
os.makedirs(models_dir)
|
||||||
|
|
||||||
@ -445,3 +448,30 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_dir(model_type: str, base_dir=None):
|
||||||
|
"Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)"
|
||||||
|
|
||||||
|
if base_dir is None:
|
||||||
|
base_dir = app.MODELS_DIR
|
||||||
|
|
||||||
|
for dir in os.listdir(base_dir):
|
||||||
|
if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)):
|
||||||
|
return os.path.join(base_dir, dir)
|
||||||
|
|
||||||
|
return os.path.join(base_dir, model_type)
|
||||||
|
|
||||||
|
|
||||||
|
# patch sdkit
|
||||||
|
def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_model_type):
|
||||||
|
"Patched version that works with case-insensitive model sub-dirs"
|
||||||
|
|
||||||
|
download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
|
||||||
|
download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir
|
||||||
|
return os.path.abspath(download_base_dir)
|
||||||
|
|
||||||
|
|
||||||
|
from sdkit.models import model_downloader
|
||||||
|
|
||||||
|
model_downloader.get_actual_base_dir = __patched__get_actual_base_dir
|
||||||
|
@ -364,15 +364,13 @@ def model_merge_internal(req: dict):
|
|||||||
|
|
||||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||||
|
|
||||||
|
sd_model_dir = model_manager.get_model_dir("stable-diffusion")
|
||||||
|
|
||||||
merge_models(
|
merge_models(
|
||||||
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||||
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
||||||
mergeReq.ratio,
|
mergeReq.ratio,
|
||||||
os.path.join(
|
os.path.join(sd_model_dir, filename_regex.sub("_", mergeReq.out_path)),
|
||||||
app.MODELS_DIR,
|
|
||||||
"stable-diffusion",
|
|
||||||
filename_regex.sub("_", mergeReq.out_path),
|
|
||||||
),
|
|
||||||
mergeReq.use_fp16,
|
mergeReq.use_fp16,
|
||||||
)
|
)
|
||||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||||
|
@ -13,7 +13,7 @@ nsfw_check_model = None
|
|||||||
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
||||||
global nsfw_check_model
|
global nsfw_check_model
|
||||||
|
|
||||||
from easydiffusion.app import MODELS_DIR
|
from easydiffusion.model_manager import get_model_dir
|
||||||
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
|
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
|
||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if nsfw_check_model is None:
|
if nsfw_check_model is None:
|
||||||
model_dir = os.path.join(MODELS_DIR, "nsfw-checker")
|
model_dir = get_model_dir("nsfw-checker")
|
||||||
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
|
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
@ -102,7 +102,7 @@ var PARAMETERS = [
|
|||||||
type: ParameterType.custom,
|
type: ParameterType.custom,
|
||||||
icon: "fa-folder-tree",
|
icon: "fa-folder-tree",
|
||||||
label: "Models Folder",
|
label: "Models Folder",
|
||||||
note: "Path to the 'models' folder. Please save and refresh the page after changing this.",
|
note: "Path to the 'models' folder. Please save and restart Easy Diffusion after changing this.",
|
||||||
saveInAppConfig: true,
|
saveInAppConfig: true,
|
||||||
render: (parameter) => {
|
render: (parameter) => {
|
||||||
return `<input id="${parameter.id}" name="${parameter.id}" size="30">`
|
return `<input id="${parameter.id}" name="${parameter.id}" size="30">`
|
||||||
|
Loading…
Reference in New Issue
Block a user