mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-14 02:05:21 +02:00
Refactor the default model download code, remove check_models.py, don't check in legacy paths since that's already migrated during initialization; Download CodeFormer's model only when it's used for the first time
This commit is contained in:
@ -1,105 +0,0 @@
|
||||
# this script runs inside the legacy "stable-diffusion" folder
|
||||
|
||||
from sdkit.models import download_model, get_model_info_from_db
|
||||
from sdkit.utils import hash_file_quick
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from glob import glob
|
||||
import traceback
|
||||
|
||||
models_base_dir = os.path.abspath(os.path.join("..", "models"))
|
||||
|
||||
models_to_check = {
|
||||
"stable-diffusion": [
|
||||
{"file_name": "sd-v1-4.ckpt", "model_id": "1.4"},
|
||||
],
|
||||
"gfpgan": [
|
||||
{"file_name": "GFPGANv1.4.pth", "model_id": "1.4"},
|
||||
],
|
||||
"realesrgan": [
|
||||
{"file_name": "RealESRGAN_x4plus.pth", "model_id": "x4plus"},
|
||||
{"file_name": "RealESRGAN_x4plus_anime_6B.pth", "model_id": "x4plus_anime_6"},
|
||||
],
|
||||
"vae": [
|
||||
{"file_name": "vae-ft-mse-840000-ema-pruned.ckpt", "model_id": "vae-ft-mse-840000-ema-pruned"},
|
||||
],
|
||||
"codeformer": [
|
||||
{"file_name": "codeformer.pth", "model_id": "codeformer-0.1.0"},
|
||||
],
|
||||
}
|
||||
MODEL_EXTENSIONS = { # copied from easydiffusion/model_manager.py
|
||||
"stable-diffusion": [".ckpt", ".safetensors"],
|
||||
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
||||
"hypernetwork": [".pt", ".safetensors"],
|
||||
"gfpgan": [".pth"],
|
||||
"realesrgan": [".pth"],
|
||||
"lora": [".ckpt", ".safetensors"],
|
||||
"codeformer": [".pth"],
|
||||
}
|
||||
|
||||
|
||||
def download_if_necessary(model_type: str, file_name: str, model_id: str):
|
||||
model_path = os.path.join(models_base_dir, model_type, file_name)
|
||||
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
||||
|
||||
other_models_exist = any_model_exists(model_type)
|
||||
known_model_exists = os.path.exists(model_path)
|
||||
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
|
||||
|
||||
if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
|
||||
print("> download", model_type, model_id)
|
||||
download_model(model_type, model_id, download_base_dir=models_base_dir)
|
||||
|
||||
|
||||
def init():
|
||||
migrate_legacy_model_location()
|
||||
|
||||
for model_type, models in models_to_check.items():
|
||||
for model in models:
|
||||
try:
|
||||
download_if_necessary(model_type, model["file_name"], model["model_id"])
|
||||
except:
|
||||
traceback.print_exc()
|
||||
fail(model_type)
|
||||
|
||||
print(model_type, "model(s) found.")
|
||||
|
||||
|
||||
### utilities
|
||||
def any_model_exists(model_type: str) -> bool:
|
||||
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
for ext in extensions:
|
||||
if any(glob(f"{models_base_dir}/{model_type}/**/*{ext}", recursive=True)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def migrate_legacy_model_location():
|
||||
'Move the models inside the legacy "stable-diffusion" folder, to their respective folders'
|
||||
|
||||
for model_type, models in models_to_check.items():
|
||||
for model in models:
|
||||
file_name = model["file_name"]
|
||||
if os.path.exists(file_name):
|
||||
dest_dir = os.path.join(models_base_dir, model_type)
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
shutil.move(file_name, os.path.join(dest_dir, file_name))
|
||||
|
||||
|
||||
def fail(model_name):
|
||||
print(
|
||||
f"""Error downloading the {model_name} model. Sorry about that, please try to:
|
||||
1. Run this installer again.
|
||||
2. If that doesn't fix it, please try to download the file manually. The address to download from, and the destination to save to are printed above this message.
|
||||
3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB
|
||||
4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues
|
||||
Thanks!"""
|
||||
)
|
||||
exit(1)
|
||||
|
||||
|
||||
### start
|
||||
|
||||
init()
|
@ -79,13 +79,6 @@ call WHERE uvicorn > .tmp
|
||||
@echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt
|
||||
)
|
||||
|
||||
@rem Download the required models
|
||||
call python ..\scripts\check_models.py
|
||||
if "%ERRORLEVEL%" NEQ "0" (
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
||||
@if "%ERRORLEVEL%" NEQ "0" (
|
||||
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
|
||||
|
@ -51,12 +51,6 @@ if ! command -v uvicorn &> /dev/null; then
|
||||
fail "UI packages not found!"
|
||||
fi
|
||||
|
||||
# Download the required models
|
||||
if ! python ../scripts/check_models.py; then
|
||||
read -p "Press any key to continue"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
||||
echo sd_install_complete >> ../scripts/install_status.txt
|
||||
|
Reference in New Issue
Block a user