mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-04-02 12:06:47 +02:00
Refactor the default model download code, remove check_models.py, don't check in legacy paths since that's already migrated during initialization; Download CodeFormer's model only when it's used for the first time
This commit is contained in:
parent
0860e35d17
commit
dd95df8f02
@ -1,105 +0,0 @@
|
|||||||
# this script runs inside the legacy "stable-diffusion" folder
|
|
||||||
|
|
||||||
from sdkit.models import download_model, get_model_info_from_db
|
|
||||||
from sdkit.utils import hash_file_quick
|
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from glob import glob
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
models_base_dir = os.path.abspath(os.path.join("..", "models"))
|
|
||||||
|
|
||||||
models_to_check = {
|
|
||||||
"stable-diffusion": [
|
|
||||||
{"file_name": "sd-v1-4.ckpt", "model_id": "1.4"},
|
|
||||||
],
|
|
||||||
"gfpgan": [
|
|
||||||
{"file_name": "GFPGANv1.4.pth", "model_id": "1.4"},
|
|
||||||
],
|
|
||||||
"realesrgan": [
|
|
||||||
{"file_name": "RealESRGAN_x4plus.pth", "model_id": "x4plus"},
|
|
||||||
{"file_name": "RealESRGAN_x4plus_anime_6B.pth", "model_id": "x4plus_anime_6"},
|
|
||||||
],
|
|
||||||
"vae": [
|
|
||||||
{"file_name": "vae-ft-mse-840000-ema-pruned.ckpt", "model_id": "vae-ft-mse-840000-ema-pruned"},
|
|
||||||
],
|
|
||||||
"codeformer": [
|
|
||||||
{"file_name": "codeformer.pth", "model_id": "codeformer-0.1.0"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
MODEL_EXTENSIONS = { # copied from easydiffusion/model_manager.py
|
|
||||||
"stable-diffusion": [".ckpt", ".safetensors"],
|
|
||||||
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
|
||||||
"hypernetwork": [".pt", ".safetensors"],
|
|
||||||
"gfpgan": [".pth"],
|
|
||||||
"realesrgan": [".pth"],
|
|
||||||
"lora": [".ckpt", ".safetensors"],
|
|
||||||
"codeformer": [".pth"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def download_if_necessary(model_type: str, file_name: str, model_id: str):
|
|
||||||
model_path = os.path.join(models_base_dir, model_type, file_name)
|
|
||||||
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
|
||||||
|
|
||||||
other_models_exist = any_model_exists(model_type)
|
|
||||||
known_model_exists = os.path.exists(model_path)
|
|
||||||
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
|
|
||||||
|
|
||||||
if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
|
|
||||||
print("> download", model_type, model_id)
|
|
||||||
download_model(model_type, model_id, download_base_dir=models_base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def init():
|
|
||||||
migrate_legacy_model_location()
|
|
||||||
|
|
||||||
for model_type, models in models_to_check.items():
|
|
||||||
for model in models:
|
|
||||||
try:
|
|
||||||
download_if_necessary(model_type, model["file_name"], model["model_id"])
|
|
||||||
except:
|
|
||||||
traceback.print_exc()
|
|
||||||
fail(model_type)
|
|
||||||
|
|
||||||
print(model_type, "model(s) found.")
|
|
||||||
|
|
||||||
|
|
||||||
### utilities
|
|
||||||
def any_model_exists(model_type: str) -> bool:
|
|
||||||
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
|
||||||
for ext in extensions:
|
|
||||||
if any(glob(f"{models_base_dir}/{model_type}/**/*{ext}", recursive=True)):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_legacy_model_location():
|
|
||||||
'Move the models inside the legacy "stable-diffusion" folder, to their respective folders'
|
|
||||||
|
|
||||||
for model_type, models in models_to_check.items():
|
|
||||||
for model in models:
|
|
||||||
file_name = model["file_name"]
|
|
||||||
if os.path.exists(file_name):
|
|
||||||
dest_dir = os.path.join(models_base_dir, model_type)
|
|
||||||
os.makedirs(dest_dir, exist_ok=True)
|
|
||||||
shutil.move(file_name, os.path.join(dest_dir, file_name))
|
|
||||||
|
|
||||||
|
|
||||||
def fail(model_name):
|
|
||||||
print(
|
|
||||||
f"""Error downloading the {model_name} model. Sorry about that, please try to:
|
|
||||||
1. Run this installer again.
|
|
||||||
2. If that doesn't fix it, please try to download the file manually. The address to download from, and the destination to save to are printed above this message.
|
|
||||||
3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB
|
|
||||||
4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues
|
|
||||||
Thanks!"""
|
|
||||||
)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
### start
|
|
||||||
|
|
||||||
init()
|
|
@ -79,13 +79,6 @@ call WHERE uvicorn > .tmp
|
|||||||
@echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt
|
@echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt
|
||||||
)
|
)
|
||||||
|
|
||||||
@rem Download the required models
|
|
||||||
call python ..\scripts\check_models.py
|
|
||||||
if "%ERRORLEVEL%" NEQ "0" (
|
|
||||||
pause
|
|
||||||
exit /b
|
|
||||||
)
|
|
||||||
|
|
||||||
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
||||||
@if "%ERRORLEVEL%" NEQ "0" (
|
@if "%ERRORLEVEL%" NEQ "0" (
|
||||||
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
|
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
|
||||||
|
@ -51,12 +51,6 @@ if ! command -v uvicorn &> /dev/null; then
|
|||||||
fail "UI packages not found!"
|
fail "UI packages not found!"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Download the required models
|
|
||||||
if ! python ../scripts/check_models.py; then
|
|
||||||
read -p "Press any key to continue"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||||
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
||||||
echo sd_install_complete >> ../scripts/install_status.txt
|
echo sd_install_complete >> ../scripts/install_status.txt
|
||||||
|
@ -90,8 +90,8 @@ def init():
|
|||||||
os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True)
|
os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True)
|
||||||
|
|
||||||
# https://pytorch.org/docs/stable/storage.html
|
# https://pytorch.org/docs/stable/storage.html
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
|
||||||
|
|
||||||
load_server_plugins()
|
load_server_plugins()
|
||||||
|
|
||||||
update_render_threads()
|
update_render_threads()
|
||||||
@ -221,12 +221,41 @@ def open_browser():
|
|||||||
|
|
||||||
webbrowser.open(f"http://localhost:{port}")
|
webbrowser.open(f"http://localhost:{port}")
|
||||||
|
|
||||||
Console().print(Panel(
|
Console().print(
|
||||||
"\n" +
|
Panel(
|
||||||
"[white]Easy Diffusion is ready to serve requests.\n\n" +
|
"\n"
|
||||||
"A new browser tab should have been opened by now.\n" +
|
+ "[white]Easy Diffusion is ready to serve requests.\n\n"
|
||||||
f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n",
|
+ "A new browser tab should have been opened by now.\n"
|
||||||
title="Easy Diffusion is ready", style="bold yellow on blue"))
|
+ f"If not, please open your web browser and navigate to [bold yellow underline]http://localhost:{port}/\n",
|
||||||
|
title="Easy Diffusion is ready",
|
||||||
|
style="bold yellow on blue",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fail_and_die(fail_type: str, data: str):
|
||||||
|
suggestions = [
|
||||||
|
"Run this installer again.",
|
||||||
|
"If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB",
|
||||||
|
"If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues",
|
||||||
|
]
|
||||||
|
|
||||||
|
if fail_type == "model_download":
|
||||||
|
fail_label = f"Error downloading the {data} model"
|
||||||
|
suggestions.insert(
|
||||||
|
1,
|
||||||
|
"If that doesn't fix it, please try to download the file manually. The address to download from, and the destination to save to are printed above this message.",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fail_label = "Error while installing Easy Diffusion"
|
||||||
|
|
||||||
|
msg = [f"{fail_label}. Sorry about that, please try to:"]
|
||||||
|
for i, suggestion in enumerate(suggestions):
|
||||||
|
msg.append(f"{i+1}. {suggestion}")
|
||||||
|
msg.append("Thanks!")
|
||||||
|
|
||||||
|
print("\n".join(msg))
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
def get_image_modifiers():
|
def get_image_modifiers():
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
from glob import glob
|
||||||
|
import traceback
|
||||||
|
|
||||||
from easydiffusion import app
|
from easydiffusion import app
|
||||||
from easydiffusion.types import TaskData
|
from easydiffusion.types import TaskData
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
from sdkit import Context
|
from sdkit import Context
|
||||||
from sdkit.models import load_model, scan_model, unload_model
|
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
|
||||||
|
from sdkit.utils import hash_file_quick
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = [
|
KNOWN_MODEL_TYPES = [
|
||||||
"stable-diffusion",
|
"stable-diffusion",
|
||||||
@ -25,12 +29,19 @@ MODEL_EXTENSIONS = {
|
|||||||
"codeformer": [".pth"],
|
"codeformer": [".pth"],
|
||||||
}
|
}
|
||||||
DEFAULT_MODELS = {
|
DEFAULT_MODELS = {
|
||||||
"stable-diffusion": [ # needed to support the legacy installations
|
"stable-diffusion": [
|
||||||
"custom-model", # only one custom model file was supported initially, creatively named 'custom-model'
|
{"file_name": "sd-v1-4.ckpt", "model_id": "1.4"},
|
||||||
"sd-v1-4", # Default fallback.
|
],
|
||||||
|
"gfpgan": [
|
||||||
|
{"file_name": "GFPGANv1.4.pth", "model_id": "1.4"},
|
||||||
|
],
|
||||||
|
"realesrgan": [
|
||||||
|
{"file_name": "RealESRGAN_x4plus.pth", "model_id": "x4plus"},
|
||||||
|
{"file_name": "RealESRGAN_x4plus_anime_6B.pth", "model_id": "x4plus_anime_6"},
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
{"file_name": "vae-ft-mse-840000-ema-pruned.ckpt", "model_id": "vae-ft-mse-840000-ema-pruned"},
|
||||||
],
|
],
|
||||||
"gfpgan": ["GFPGANv1.3"],
|
|
||||||
"realesrgan": ["RealESRGAN_x4plus"],
|
|
||||||
}
|
}
|
||||||
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
|
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"]
|
||||||
|
|
||||||
@ -39,6 +50,8 @@ known_models = {}
|
|||||||
|
|
||||||
def init():
|
def init():
|
||||||
make_model_folders()
|
make_model_folders()
|
||||||
|
migrate_legacy_model_location() # if necessary
|
||||||
|
download_default_models_if_necessary()
|
||||||
getModels() # run this once, to cache the picklescan results
|
getModels() # run this once, to cache the picklescan results
|
||||||
|
|
||||||
|
|
||||||
@ -77,7 +90,7 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None):
|
|||||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
|
||||||
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
|
model_dir = os.path.join(app.MODELS_DIR, model_type)
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
# config = getConfig()
|
# config = getConfig()
|
||||||
if "model" in config and model_type in config["model"]:
|
if "model" in config and model_type in config["model"]:
|
||||||
@ -85,31 +98,25 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None):
|
|||||||
|
|
||||||
if model_name:
|
if model_name:
|
||||||
# Check models directory
|
# Check models directory
|
||||||
models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
|
model_path = os.path.join(model_dir, model_name)
|
||||||
|
if os.path.exists(model_path):
|
||||||
|
return model_path
|
||||||
for model_extension in model_extensions:
|
for model_extension in model_extensions:
|
||||||
if os.path.exists(models_dir_path + model_extension):
|
if os.path.exists(model_path + model_extension):
|
||||||
return models_dir_path + model_extension
|
return model_path + model_extension
|
||||||
if os.path.exists(model_name + model_extension):
|
if os.path.exists(model_name + model_extension):
|
||||||
return os.path.abspath(model_name + model_extension)
|
return os.path.abspath(model_name + model_extension)
|
||||||
|
|
||||||
# Default locations
|
|
||||||
if model_name in default_models:
|
|
||||||
default_model_path = os.path.join(app.SD_DIR, model_name)
|
|
||||||
for model_extension in model_extensions:
|
|
||||||
if os.path.exists(default_model_path + model_extension):
|
|
||||||
return default_model_path + model_extension
|
|
||||||
|
|
||||||
# Can't find requested model, check the default paths.
|
# Can't find requested model, check the default paths.
|
||||||
for default_model in default_models:
|
if model_type == "stable-diffusion":
|
||||||
for model_dir in model_dirs:
|
for default_model in default_models:
|
||||||
default_model_path = os.path.join(model_dir, default_model)
|
default_model_path = os.path.join(model_dir, default_model["file_name"])
|
||||||
for model_extension in model_extensions:
|
if os.path.exists(default_model_path):
|
||||||
if os.path.exists(default_model_path + model_extension):
|
if model_name is not None:
|
||||||
if model_name is not None:
|
log.warn(
|
||||||
log.warn(
|
f"Could not find the configured custom model {model_name}. Using the default one: {default_model_path}"
|
||||||
f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}"
|
)
|
||||||
)
|
return default_model_path
|
||||||
return default_model_path + model_extension
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -136,7 +143,9 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if task_data.codeformer_upscale_faces and "realesrgan" not in models_to_reload.keys():
|
if task_data.codeformer_upscale_faces and "realesrgan" not in models_to_reload.keys():
|
||||||
models_to_reload["realesrgan"] = resolve_model_to_use(DEFAULT_MODELS["realesrgan"][0], "realesrgan")
|
models_to_reload["realesrgan"] = resolve_model_to_use(
|
||||||
|
DEFAULT_MODELS["realesrgan"][0]["file_name"], "realesrgan"
|
||||||
|
)
|
||||||
|
|
||||||
if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD
|
if set_vram_optimizations(context) or set_clip_skip(context, task_data): # reload SD
|
||||||
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
|
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
|
||||||
@ -168,6 +177,7 @@ def resolve_model_paths(task_data: TaskData):
|
|||||||
model_type = "gfpgan"
|
model_type = "gfpgan"
|
||||||
elif "codeformer" in task_data.use_face_correction.lower():
|
elif "codeformer" in task_data.use_face_correction.lower():
|
||||||
model_type = "codeformer"
|
model_type = "codeformer"
|
||||||
|
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
|
||||||
|
|
||||||
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type)
|
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, model_type)
|
||||||
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
|
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
|
||||||
@ -179,7 +189,31 @@ def fail_if_models_did_not_load(context: Context):
|
|||||||
if model_type in context.model_load_errors:
|
if model_type in context.model_load_errors:
|
||||||
e = context.model_load_errors[model_type]
|
e = context.model_load_errors[model_type]
|
||||||
raise Exception(f"Could not load the {model_type} model! Reason: " + e)
|
raise Exception(f"Could not load the {model_type} model! Reason: " + e)
|
||||||
# concat 'e', don't use in format string (injection attack)
|
|
||||||
|
|
||||||
|
def download_default_models_if_necessary():
|
||||||
|
for model_type, models in DEFAULT_MODELS.items():
|
||||||
|
for model in models:
|
||||||
|
try:
|
||||||
|
download_if_necessary(model_type, model["file_name"], model["model_id"])
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
|
app.fail_and_die(fail_type="model_download", data=model_type)
|
||||||
|
|
||||||
|
print(model_type, "model(s) found.")
|
||||||
|
|
||||||
|
|
||||||
|
def download_if_necessary(model_type: str, file_name: str, model_id: str):
|
||||||
|
model_path = os.path.join(app.MODELS_DIR, model_type, file_name)
|
||||||
|
expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
|
||||||
|
|
||||||
|
other_models_exist = any_model_exists(model_type)
|
||||||
|
known_model_exists = os.path.exists(model_path)
|
||||||
|
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
|
||||||
|
|
||||||
|
if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
|
||||||
|
print("> download", model_type, model_id)
|
||||||
|
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR)
|
||||||
|
|
||||||
|
|
||||||
def set_vram_optimizations(context: Context):
|
def set_vram_optimizations(context: Context):
|
||||||
@ -193,6 +227,26 @@ def set_vram_optimizations(context: Context):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_legacy_model_location():
|
||||||
|
'Move the models inside the legacy "stable-diffusion" folder, to their respective folders'
|
||||||
|
|
||||||
|
for model_type, models in DEFAULT_MODELS.items():
|
||||||
|
for model in models:
|
||||||
|
file_name = model["file_name"]
|
||||||
|
legacy_path = os.path.join(app.SD_DIR, file_name)
|
||||||
|
if os.path.exists(legacy_path):
|
||||||
|
shutil.move(legacy_path, os.path.join(app.MODELS_DIR, model_type, file_name))
|
||||||
|
|
||||||
|
|
||||||
|
def any_model_exists(model_type: str) -> bool:
|
||||||
|
extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
|
for ext in extensions:
|
||||||
|
if any(glob(f"{app.MODELS_DIR}/{model_type}/**/*{ext}", recursive=True)):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def set_clip_skip(context: Context, task_data: TaskData):
|
def set_clip_skip(context: Context, task_data: TaskData):
|
||||||
clip_skip = task_data.clip_skip
|
clip_skip = task_data.clip_skip
|
||||||
|
|
||||||
@ -255,7 +309,7 @@ def getModels():
|
|||||||
"vae": [],
|
"vae": [],
|
||||||
"hypernetwork": [],
|
"hypernetwork": [],
|
||||||
"lora": [],
|
"lora": [],
|
||||||
"codeformer": [],
|
"codeformer": ["codeformer"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,14 +366,8 @@ def getModels():
|
|||||||
listModels(model_type="hypernetwork")
|
listModels(model_type="hypernetwork")
|
||||||
listModels(model_type="gfpgan")
|
listModels(model_type="gfpgan")
|
||||||
listModels(model_type="lora")
|
listModels(model_type="lora")
|
||||||
listModels(model_type="codeformer")
|
|
||||||
|
|
||||||
if models_scanned > 0:
|
if models_scanned > 0:
|
||||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||||
|
|
||||||
# legacy
|
|
||||||
custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt")
|
|
||||||
if os.path.exists(custom_weight_path):
|
|
||||||
models["options"]["stable-diffusion"].append("custom-model")
|
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
Loading…
Reference in New Issue
Block a user