diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index b9599834..99fb74bc 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -182,6 +182,16 @@ call WHERE uvicorn > .tmp ) ) +@>nul 2>nul call python -c "import safetensors" +@if "%ERRORLEVEL%" NEQ "0" ( + @echo. & echo SafeTensors not found. Installing + @call pip install safetensors || ( + echo "Error installing the safetensors package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" + pause + exit /b + ) +) + @>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt @if "%ERRORLEVEL%" NEQ "0" ( @echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 3054f19c..177e4f73 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -150,6 +150,13 @@ else pip install picklescan || fail "Picklescan installation failed." fi +if python -c "import safetensors" --help >/dev/null 2>&1; then + echo "SafeTensors is already installed." +else + echo "SafeTensors not found, installing." + pip install safetensors || fail "SafeTensors installation failed." +fi + mkdir -p "../models/stable-diffusion" diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 8500186a..ed1d69a7 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -29,6 +29,7 @@ from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer from threading import Lock +from safetensors.torch import load_file import uuid @@ -96,7 +97,12 @@ def isSD2(): def load_model_ckpt(): if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.') - if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt') + if os.path.exists(thread_data.ckpt_file + '.ckpt'): + thread_data.ckpt_file += '.ckpt' + elif os.path.exists(thread_data.ckpt_file + '.safetensors'): + thread_data.ckpt_file += '.safetensors' + elif not os.path.exists(thread_data.ckpt_file): + raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors') if not thread_data.precision: thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast' @@ -107,7 +113,7 @@ def load_model_ckpt(): if thread_data.device == 'cpu': thread_data.precision = 'full' - print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision) + print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision) if thread_data.test_sd2: load_model_ckpt_sd2() @@ -115,7 +121,7 @@ def load_model_ckpt(): load_model_ckpt_sd1() def load_model_ckpt_sd1(): - sd = load_model_from_config(thread_data.ckpt_file + '.ckpt') + sd = load_model_from_config(thread_data.ckpt_file) li, lo = [], [] for key, value in sd.items(): sp = key.split(".") @@ -202,7 +208,7 @@ def load_model_ckpt_sd1(): thread_data.model_fs_is_half = False print(f'''loaded model - model file: {thread_data.ckpt_file}.ckpt + model file: {thread_data.ckpt_file} model.device: {model.device} modelCS.device: {modelCS.cond_stage_model.device} modelFS.device: {thread_data.modelFS.device} @@ -213,7 +219,7 @@ def load_model_ckpt_sd2(): config = OmegaConf.load(config_file) verbose = False - sd = load_model_from_config(thread_data.ckpt_file + '.ckpt') + sd = load_model_from_config(thread_data.ckpt_file) thread_data.model = instantiate_from_config(config.model) m, u = thread_data.model.load_state_dict(sd, strict=False) @@ -239,7 +245,7 @@ def load_model_ckpt_sd2(): thread_data.model_fs_is_half = False print(f'''loaded model - model file: {thread_data.ckpt_file}.ckpt + model file: {thread_data.ckpt_file} using precision: {thread_data.precision}''') def unload_filters(): @@ -401,7 +407,12 @@ def is_model_reload_necessary(req: Request): # custom model support: # the req.use_stable_diffusion_model needs to be a valid path # to the ckpt file (without the extension). - if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt') + if os.path.exists(req.use_stable_diffusion_model + '.ckpt'): + req.use_stable_diffusion_model += '.ckpt' + elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'): + req.use_stable_diffusion_model += '.safetensors' + elif not os.path.exists(req.use_stable_diffusion_model): + raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors') needs_model_reload = False if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: @@ -869,13 +880,21 @@ def chunk(it, size): def load_model_from_config(ckpt, verbose=False): print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") + + if ckpt.endswith(".safetensors"): + print("Loading from safetensors") + pl_sd = load_file(ckpt, device="cpu") + else: + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - return sd -# utils + if "state_dict" in pl_sd: + return pl_sd["state_dict"] + else: + return pl_sd + class UserInitiatedStop(Exception): pass diff --git a/ui/server.py b/ui/server.py index ffdb5ce2..c7760889 100644 --- a/ui/server.py +++ b/ui/server.py @@ -24,7 +24,7 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui' CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt'] +STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder