mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-04 00:56:23 +02:00
Support both WebUI and ED folder names for models
This commit is contained in:
parent
9abc76482c
commit
b6ba782c35
@ -7,7 +7,7 @@ import psutil
|
|||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
from easydiffusion.app import ROOT_DIR, getConfig
|
from easydiffusion.app import ROOT_DIR, getConfig
|
||||||
from easydiffusion.model_manager import get_model_dir
|
from easydiffusion.model_manager import get_model_dirs
|
||||||
|
|
||||||
from . import impl
|
from . import impl
|
||||||
from .impl import (
|
from .impl import (
|
||||||
@ -74,6 +74,8 @@ def install_backend():
|
|||||||
# create the conda env
|
# create the conda env
|
||||||
run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR)
|
run([conda, "create", "-y", "--prefix", SYSTEM_DIR], cwd=ROOT_DIR)
|
||||||
|
|
||||||
|
print("Installing packages..")
|
||||||
|
|
||||||
# install python 3.10 and git in the conda env
|
# install python 3.10 and git in the conda env
|
||||||
run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR)
|
run([conda, "install", "-y", "--prefix", SYSTEM_DIR, "-c", "conda-forge", "python=3.10", "git"], cwd=ROOT_DIR)
|
||||||
|
|
||||||
@ -336,7 +338,7 @@ def kill(proc_pid):
|
|||||||
def get_model_path_args():
|
def get_model_path_args():
|
||||||
args = []
|
args = []
|
||||||
for model_type, flag in MODELS_TO_OVERRIDE.items():
|
for model_type, flag in MODELS_TO_OVERRIDE.items():
|
||||||
model_dir = get_model_dir(model_type)
|
model_dir = get_model_dirs(model_type)[0]
|
||||||
args.append(f'{flag} "{model_dir}"')
|
args.append(f'{flag} "{model_dir}"')
|
||||||
|
|
||||||
return " ".join(args)
|
return " ".join(args)
|
||||||
|
@ -51,6 +51,16 @@ DEFAULT_MODELS = {
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
|
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
|
||||||
|
ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility
|
||||||
|
"stable-diffusion": "Stable-diffusion",
|
||||||
|
"vae": "VAE",
|
||||||
|
"hypernetwork": "hypernetworks",
|
||||||
|
"codeformer": "Codeformer",
|
||||||
|
"gfpgan": "GFPGAN",
|
||||||
|
"realesrgan": "RealESRGAN",
|
||||||
|
"lora": "Lora",
|
||||||
|
"controlnet": "ControlNet",
|
||||||
|
}
|
||||||
|
|
||||||
known_models = {}
|
known_models = {}
|
||||||
|
|
||||||
@ -122,33 +132,33 @@ def resolve_model_to_use_single(model_name: str = None, model_type: str = None,
|
|||||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
|
||||||
model_dir = get_model_dir(model_type)
|
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
# config = getConfig()
|
# config = getConfig()
|
||||||
if "model" in config and model_type in config["model"]:
|
if "model" in config and model_type in config["model"]:
|
||||||
model_name = config["model"][model_type]
|
model_name = config["model"][model_type]
|
||||||
|
|
||||||
if model_name:
|
for model_dir in get_model_dirs(model_type):
|
||||||
# Check models directory
|
if model_name:
|
||||||
model_path = os.path.join(model_dir, model_name)
|
# Check models directory
|
||||||
if os.path.exists(model_path):
|
model_path = os.path.join(model_dir, model_name)
|
||||||
return model_path
|
if os.path.exists(model_path):
|
||||||
for model_extension in model_extensions:
|
return model_path
|
||||||
if os.path.exists(model_path + model_extension):
|
for model_extension in model_extensions:
|
||||||
return model_path + model_extension
|
if os.path.exists(model_path + model_extension):
|
||||||
if os.path.exists(model_name + model_extension):
|
return model_path + model_extension
|
||||||
return os.path.abspath(model_name + model_extension)
|
if os.path.exists(model_name + model_extension):
|
||||||
|
return os.path.abspath(model_name + model_extension)
|
||||||
|
|
||||||
# Can't find requested model, check the default paths.
|
# Can't find requested model, check the default paths.
|
||||||
if model_type == "stable-diffusion" and not fail_if_not_found:
|
if model_type == "stable-diffusion" and not fail_if_not_found:
|
||||||
for default_model in default_models:
|
for default_model in default_models:
|
||||||
default_model_path = os.path.join(model_dir, default_model["file_name"])
|
default_model_path = os.path.join(model_dir, default_model["file_name"])
|
||||||
if os.path.exists(default_model_path):
|
if os.path.exists(default_model_path):
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
log.warn(
|
log.warn(
|
||||||
f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
|
f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
|
||||||
)
|
)
|
||||||
return default_model_path
|
return default_model_path
|
||||||
|
|
||||||
if model_name and fail_if_not_found:
|
if model_name and fail_if_not_found:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
@ -239,7 +249,7 @@ def download_default_models_if_necessary():
|
|||||||
|
|
||||||
|
|
||||||
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
|
def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
|
||||||
model_dir = get_model_dir(model_type)
|
model_dir = get_model_dirs(model_type)[0]
|
||||||
model_path = os.path.join(model_dir, file_name)
|
model_path = os.path.join(model_dir, file_name)
|
||||||
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
||||||
|
|
||||||
@ -260,23 +270,23 @@ def migrate_legacy_model_location():
|
|||||||
file_name = model["file_name"]
|
file_name = model["file_name"]
|
||||||
legacy_path = os.path.join(app.SD_DIR, file_name)
|
legacy_path = os.path.join(app.SD_DIR, file_name)
|
||||||
if os.path.exists(legacy_path):
|
if os.path.exists(legacy_path):
|
||||||
model_dir = get_model_dir(model_type)
|
model_dir = get_model_dirs(model_type)[0]
|
||||||
shutil.move(legacy_path, os.path.join(model_dir, file_name))
|
shutil.move(legacy_path, os.path.join(model_dir, file_name))
|
||||||
|
|
||||||
|
|
||||||
def any_model_exists(model_type: str) -> bool:
|
def any_model_exists(model_type: str) -> bool:
|
||||||
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
model_dir = get_model_dir(model_type)
|
for model_dir in get_model_dirs(model_type):
|
||||||
for ext in extensions:
|
for ext in extensions:
|
||||||
if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
|
if any(glob(f"{model_dir}/**/*{ext}", recursive=True)):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def make_model_folders():
|
def make_model_folders():
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
model_dir_path = get_model_dir(model_type)
|
model_dir_path = get_model_dirs(model_type)[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(model_dir_path, exist_ok=True)
|
os.makedirs(model_dir_path, exist_ok=True)
|
||||||
@ -377,6 +387,9 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
|
|
||||||
tree = list(default_entries)
|
tree = list(default_entries)
|
||||||
|
|
||||||
|
if not os.path.exists(directory):
|
||||||
|
return tree
|
||||||
|
|
||||||
for entry in sorted(
|
for entry in sorted(
|
||||||
os.scandir(directory),
|
os.scandir(directory),
|
||||||
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
||||||
@ -421,17 +434,23 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
|
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
models_dir = get_model_dir(model_type)
|
models_dirs = get_model_dirs(model_type)
|
||||||
if not os.path.exists(models_dir):
|
if not os.path.exists(models_dirs[0]):
|
||||||
os.makedirs(models_dir)
|
os.makedirs(models_dirs[0])
|
||||||
|
|
||||||
try:
|
models["options"][model_type] = []
|
||||||
default_tree = models["options"].get(model_type, [])
|
default_tree = models["options"].get(model_type, [])
|
||||||
models["options"][model_type] = scan_directory(
|
|
||||||
models_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
|
for model_dir in models_dirs:
|
||||||
)
|
try:
|
||||||
except MaliciousModelException as e:
|
scanned_models = scan_directory(
|
||||||
models["scan-error"] = str(e)
|
model_dir, model_extensions, default_entries=default_tree, nameFilter=nameFilter
|
||||||
|
)
|
||||||
|
for m in scanned_models:
|
||||||
|
if m not in models["options"][model_type]:
|
||||||
|
models["options"][model_type].append(m)
|
||||||
|
except MaliciousModelException as e:
|
||||||
|
models["scan-error"] = str(e)
|
||||||
|
|
||||||
if scan_for_malicious:
|
if scan_for_malicious:
|
||||||
log.info(f"[green]Scanning all model folders for models...[/]")
|
log.info(f"[green]Scanning all model folders for models...[/]")
|
||||||
@ -450,17 +469,21 @@ def getModels(scan_for_malicious: bool = True):
|
|||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
def get_model_dir(model_type: str, base_dir=None):
|
def get_model_dirs(model_type: str, base_dir=None):
|
||||||
"Returns the case-insensitive model directory path, or the given model folder (if the model sub-dir wasn't found)"
|
"Returns the possible model directory paths for the given model type. Mainly used for WebUI compatibility"
|
||||||
|
|
||||||
if base_dir is None:
|
if base_dir is None:
|
||||||
base_dir = app.MODELS_DIR
|
base_dir = app.MODELS_DIR
|
||||||
|
|
||||||
for dir in os.listdir(base_dir):
|
dirs = [os.path.join(base_dir, model_type)]
|
||||||
if dir.lower() == model_type.lower() and os.path.isdir(os.path.join(base_dir, dir)):
|
|
||||||
return os.path.join(base_dir, dir)
|
|
||||||
|
|
||||||
return os.path.join(base_dir, model_type)
|
if model_type in ALTERNATE_FOLDER_NAMES:
|
||||||
|
alt_dir = ALTERNATE_FOLDER_NAMES[model_type]
|
||||||
|
alt_dir = os.path.join(base_dir, alt_dir)
|
||||||
|
if os.path.exists(alt_dir) and os.path.isdir(alt_dir):
|
||||||
|
dirs.append(alt_dir)
|
||||||
|
|
||||||
|
return dirs
|
||||||
|
|
||||||
|
|
||||||
# patch sdkit
|
# patch sdkit
|
||||||
@ -468,7 +491,7 @@ def __patched__get_actual_base_dir(model_type, download_base_dir, subdir_for_mod
|
|||||||
"Patched version that works with case-insensitive model sub-dirs"
|
"Patched version that works with case-insensitive model sub-dirs"
|
||||||
|
|
||||||
download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
|
download_base_dir = os.path.join("~", ".cache", "sdkit") if download_base_dir is None else download_base_dir
|
||||||
download_base_dir = get_model_dir(model_type, download_base_dir) if subdir_for_model_type else download_base_dir
|
download_base_dir = get_model_dirs(model_type, download_base_dir)[0] if subdir_for_model_type else download_base_dir
|
||||||
return os.path.abspath(download_base_dir)
|
return os.path.abspath(download_base_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@ -364,7 +364,7 @@ def model_merge_internal(req: dict):
|
|||||||
|
|
||||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||||
|
|
||||||
sd_model_dir = model_manager.get_model_dir("stable-diffusion")
|
sd_model_dir = model_manager.get_model_dir("stable-diffusion")[0]
|
||||||
|
|
||||||
merge_models(
|
merge_models(
|
||||||
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||||
|
@ -13,7 +13,7 @@ nsfw_check_model = None
|
|||||||
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
||||||
global nsfw_check_model
|
global nsfw_check_model
|
||||||
|
|
||||||
from easydiffusion.model_manager import get_model_dir
|
from easydiffusion.model_manager import get_model_dirs
|
||||||
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
|
from sdkit.utils import base64_str_to_img, img_to_base64_str, download_file, log, hash_file_quick
|
||||||
|
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
@ -21,7 +21,7 @@ def filter_nsfw(images, blur_radius: float = 75, print_log=True):
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if nsfw_check_model is None:
|
if nsfw_check_model is None:
|
||||||
model_dir = get_model_dir("nsfw-checker")
|
model_dir = get_model_dirs("nsfw-checker")[0]
|
||||||
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
|
model_path = os.path.join(model_dir, "vit-base-nsfw-detector-quantized.onnx")
|
||||||
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user