From dd95df8f022a51bb8e5b097b94fca34e9fd5105a Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 2 Jun 2023 16:34:29 +0530 Subject: [PATCH] Refactor the default model download code, remove check_models.py, don't check in legacy paths since that's already migrated during initialization; Download CodeFormer's model only when it's used for the first time --- scripts/check_models.py | 105 -------------------------- scripts/on_sd_start.bat | 7 -- scripts/on_sd_start.sh | 6 -- ui/easydiffusion/app.py | 45 +++++++++-- ui/easydiffusion/model_manager.py | 120 +++++++++++++++++++++--------- 5 files changed, 121 insertions(+), 162 deletions(-) delete mode 100644 scripts/check_models.py diff --git a/scripts/check_models.py b/scripts/check_models.py deleted file mode 100644 index a2186727..00000000 --- a/scripts/check_models.py +++ /dev/null @@ -1,105 +0,0 @@ -# this script runs inside the legacy "stable-diffusion" folder - -from sdkit.models import download_model, get_model_info_from_db -from sdkit.utils import hash_file_quick - -import os -import shutil -from glob import glob -import traceback - -models_base_dir = os.path.abspath(os.path.join("..", "models")) - -models_to_check = { - "stable-diffusion": [ - {"file_name": "sd-v1-4.ckpt", "model_id": "1.4"}, - ], - "gfpgan": [ - {"file_name": "GFPGANv1.4.pth", "model_id": "1.4"}, - ], - "realesrgan": [ - {"file_name": "RealESRGAN_x4plus.pth", "model_id": "x4plus"}, - {"file_name": "RealESRGAN_x4plus_anime_6B.pth", "model_id": "x4plus_anime_6"}, - ], - "vae": [ - {"file_name": "vae-ft-mse-840000-ema-pruned.ckpt", "model_id": "vae-ft-mse-840000-ema-pruned"}, - ], - "codeformer": [ - {"file_name": "codeformer.pth", "model_id": "codeformer-0.1.0"}, - ], -} -MODEL_EXTENSIONS = { # copied from easydiffusion/model_manager.py - "stable-diffusion": [".ckpt", ".safetensors"], - "vae": [".vae.pt", ".ckpt", ".safetensors"], - "hypernetwork": [".pt", ".safetensors"], - "gfpgan": [".pth"], - "realesrgan": [".pth"], - "lora": [".ckpt", ".safetensors"], - "codeformer": [".pth"], -} - - -def download_if_necessary(model_type: str, file_name: str, model_id: str): - model_path = os.path.join(models_base_dir, model_type, file_name) - expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] - - other_models_exist = any_model_exists(model_type) - known_model_exists = os.path.exists(model_path) - known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash - - if known_model_is_corrupt or (not other_models_exist and not known_model_exists): - print("> download", model_type, model_id) - download_model(model_type, model_id, download_base_dir=models_base_dir) - - -def init(): - migrate_legacy_model_location() - - for model_type, models in models_to_check.items(): - for model in models: - try: - download_if_necessary(model_type, model["file_name"], model["model_id"]) - except: - traceback.print_exc() - fail(model_type) - - print(model_type, "model(s) found.") - - -### utilities -def any_model_exists(model_type: str) -> bool: - extensions = MODEL_EXTENSIONS.get(model_type, []) - for ext in extensions: - if any(glob(f"{models_base_dir}/{model_type}/**/*{ext}", recursive=True)): - return True - - return False - - -def migrate_legacy_model_location(): - 'Move the models inside the legacy "stable-diffusion" folder, to their respective folders' - - for model_type, models in models_to_check.items(): - for model in models: - file_name = model["file_name"] - if os.path.exists(file_name): - dest_dir = os.path.join(models_base_dir, model_type) - os.makedirs(dest_dir, exist_ok=True) - shutil.move(file_name, os.path.join(dest_dir, file_name)) - - -def fail(model_name): - print( - f"""Error downloading the {model_name} model. Sorry about that, please try to: -1. Run this installer again. -2. If that doesn't fix it, please try to download the file manually. The address to download from, and the destination to save to are printed above this message. -3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB -4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues -Thanks!""" - ) - exit(1) - - -### start - -init() diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index ba205c9e..d2ef7321 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -79,13 +79,6 @@ call WHERE uvicorn > .tmp @echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt ) -@rem Download the required models -call python ..\scripts\check_models.py -if "%ERRORLEVEL%" NEQ "0" ( - pause - exit /b -) - @>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt @if "%ERRORLEVEL%" NEQ "0" ( @echo sd_weights_downloaded >> ..\scripts\install_status.txt diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 820c36ed..55f4da25 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -51,12 +51,6 @@ if ! command -v uvicorn &> /dev/null; then fail "UI packages not found!" fi -# Download the required models -if ! python ../scripts/check_models.py; then - read -p "Press any key to continue" - exit 1 -fi - if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then echo sd_weights_downloaded >> ../scripts/install_status.txt echo sd_install_complete >> ../scripts/install_status.txt diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index 3064e151..38e3392c 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -90,8 +90,8 @@ def init(): os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True) # https://pytorch.org/docs/stable/storage.html - warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') - + warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") + load_server_plugins() update_render_threads() @@ -221,12 +221,41 @@ def open_browser(): webbrowser.open(f"http://localhost:{port}") - Console().print(Panel( - "\n" + - "[white]Easy Diffusion is ready to serve requests.\n\n" + - "A new browser tab should have been opened by now.\n" + - f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n", - title="Easy Diffusion is ready", style="bold yellow on blue")) + Console().print( + Panel( + "\n" + + "[white]Easy Diffusion is ready to serve requests.\n\n" + + "A new browser tab should have been opened by now.\n" + + f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n", + title="Easy Diffusion is ready", + style="bold yellow on blue", + ) + ) + + +def fail_and_die(fail_type: str, data: str): + suggestions = [ + "Run this installer again.", + "If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB", + "If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues", + ] + + if fail_type == "model_download": + fail_label = f"Error downloading the {data} model" + suggestions.insert( + 1, + "If that doesn't fix it, please try to download the file manually. The address to download from, and the destination to save to are printed above this message.", + ) + else: + fail_label = "Error while installing Easy Diffusion" + + msg = [f"{fail_label}. Sorry about that, please try to:"] + for i, suggestion in enumerate(suggestions): + msg.append(f"{i+1}. {suggestion}") + msg.append("Thanks!") + + print("\n".join(msg)) + exit(1) def get_image_modifiers(): diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 458dae7f..c4447033 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -1,10 +1,14 @@ import os +import shutil +from glob import glob +import traceback from easydiffusion import app from easydiffusion.types import TaskData from easydiffusion.utils import log from sdkit import Context -from sdkit.models import load_model, scan_model, unload_model +from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db +from sdkit.utils import hash_file_quick KNOWN_MODEL_TYPES = [ "stable-diffusion", @@ -25,12 +29,19 @@ MODEL_EXTENSIONS = { "codeformer": [".pth"], } DEFAULT_MODELS = { - "stable-diffusion": [ # needed to support the legacy installations - "custom-model", # only one custom model file was supported initially, creatively named 'custom-model' - "sd-v1-4", # Default fallback. + "stable-diffusion": [ + {"file_name": "sd-v1-4.ckpt", "model_id": "1.4"}, + ], + "gfpgan": [ + {"file_name": "GFPGANv1.4.pth", "model_id": "1.4"}, + ], + "realesrgan": [ + {"file_name": "RealESRGAN_x4plus.pth", "model_id": "x4plus"}, + {"file_name": "RealESRGAN_x4plus_anime_6B.pth", "model_id": "x4plus_anime_6"}, + ], + "vae": [ + {"file_name": "vae-ft-mse-840000-ema-pruned.ckpt", "model_id": "vae-ft-mse-840000-ema-pruned"}, ], - "gfpgan": ["GFPGANv1.3"], - "realesrgan": ["RealESRGAN_x4plus"], } MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"] @@ -39,6 +50,8 @@ known_models = {} def init(): make_model_folders() + migrate_legacy_model_location() # if necessary + download_default_models_if_necessary() getModels() # run this once, to cache the picklescan results @@ -77,7 +90,7 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None): default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() - model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR] + model_dir = os.path.join(app.MODELS_DIR, model_type) if not model_name: # When None try user configured model. # config = getConfig() if "model" in config and model_type in config["model"]: @@ -85,31 +98,25 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None): if model_name: # Check models directory - models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name) + model_path = os.path.join(model_dir, model_name) + if os.path.exists(model_path): + return model_path for model_extension in model_extensions: - if os.path.exists(models_dir_path + model_extension): - return models_dir_path + model_extension + if os.path.exists(model_path + model_extension): + return model_path + model_extension if os.path.exists(model_name + model_extension): return os.path.abspath(model_name + model_extension) - # Default locations - if model_name in default_models: - default_model_path = os.path.join(app.SD_DIR, model_name) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - return default_model_path + model_extension - # Can't find requested model, check the default paths. - for default_model in default_models: - for model_dir in model_dirs: - default_model_path = os.path.join(model_dir, default_model) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - if model_name is not None: - log.warn( - f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}" - ) - return default_model_path + model_extension + if model_type == "stable-diffusion": + for default_model in default_models: + default_model_path = os.path.join(model_dir, default_model["file_name"]) + if os.path.exists(default_model_path): + if model_name is not None: + log.warn( + f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}" + ) + return default_model_path return None @@ -136,7 +143,9 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): } if task_data.codeformer_upscale_faces and "realesrgan" not in models_to_reload.keys(): - models_to_reload["realesrgan"] = resolve_model_to_use(DEFAULT_MODELS["realesrgan"][0], "realesrgan") + models_to_reload["realesrgan"] = resolve_model_to_use( + DEFAULT_MODELS["realesrgan"][0]["file_name"], "realesrgan" + ) if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"] @@ -168,6 +177,7 @@ def resolve_model_paths(task_data: TaskData): model_type = "gfpgan" elif "codeformer" in task_data.use_face_correction.lower(): model_type = "codeformer" + download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type) if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): @@ -179,7 +189,31 @@ def fail_if_models_did_not_load(context: Context): if model_type in context.model_load_errors: e = context.model_load_errors[model_type] raise Exception(f"Could not load the {model_type} model! Reason: " + e) - # concat 'e', don't use in format string (injection attack) + + +def download_default_models_if_necessary(): + for model_type, models in DEFAULT_MODELS.items(): + for model in models: + try: + download_if_necessary(model_type, model["file_name"], model["model_id"]) + except: + traceback.print_exc() + app.fail_and_die(fail_type="model_download", data=model_type) + + print(model_type, "model(s) found.") + + +def download_if_necessary(model_type: str, file_name: str, model_id: str): + model_path = os.path.join(app.MODELS_DIR, model_type, file_name) + expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] + + other_models_exist = any_model_exists(model_type) + known_model_exists = os.path.exists(model_path) + known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash + + if known_model_is_corrupt or (not other_models_exist and not known_model_exists): + print("> download", model_type, model_id) + download_model(model_type, model_id, download_base_dir=app.MODELS_DIR) def set_vram_optimizations(context: Context): @@ -193,6 +227,26 @@ def set_vram_optimizations(context: Context): return False +def migrate_legacy_model_location(): + 'Move the models inside the legacy "stable-diffusion" folder, to their respective folders' + + for model_type, models in DEFAULT_MODELS.items(): + for model in models: + file_name = model["file_name"] + legacy_path = os.path.join(app.SD_DIR, file_name) + if os.path.exists(legacy_path): + shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name)) + + +def any_model_exists(model_type: str) -> bool: + extensions = MODEL_EXTENSIONS.get(model_type, []) + for ext in extensions: + if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)): + return True + + return False + + def set_clip_skip(context: Context, task_data: TaskData): clip_skip = task_data.clip_skip @@ -255,7 +309,7 @@ def getModels(): "vae": [], "hypernetwork": [], "lora": [], - "codeformer": [], + "codeformer": ["codeformer"], }, } @@ -312,14 +366,8 @@ def getModels(): listModels(model_type="hypernetwork") listModels(model_type="gfpgan") listModels(model_type="lora") - listModels(model_type="codeformer") if models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") - # legacy - custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt") - if os.path.exists(custom_weight_path): - models["options"]["stable-diffusion"].append("custom-model") - return models