Safetensor support (Fixes #599) (#608)

* safetensors support
Add support for checkpoints in safetensors format: https://github.com/huggingface/safetensors

This format shall be safer than pickle files

* pip install safetensors
This commit is contained in:
JeLuF 2022-12-05 06:29:48 +01:00 committed by GitHub
parent f701b8dc29
commit 7861c57317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 12 deletions

View File

@ -182,6 +182,16 @@ call WHERE uvicorn > .tmp
)
)
@>nul 2>nul call python -c "import safetensors"
@if "%ERRORLEVEL%" NEQ "0" (
@echo. & echo SafeTensors not found. Installing
@call pip install safetensors || (
echo "Error installing the safetensors package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
)
)
@>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (
@echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt

View File

@ -150,6 +150,13 @@ else
pip install picklescan || fail "Picklescan installation failed."
fi
if python -c "import safetensors" --help >/dev/null 2>&1; then
echo "SafeTensors is already installed."
else
echo "SafeTensors not found, installing."
pip install safetensors || fail "SafeTensors installation failed."
fi
mkdir -p "../models/stable-diffusion"

View File

@ -29,6 +29,7 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from threading import Lock
from safetensors.torch import load_file
import uuid
@ -96,7 +97,12 @@ def isSD2():
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
if os.path.exists(thread_data.ckpt_file + '.ckpt'):
thread_data.ckpt_file += '.ckpt'
elif os.path.exists(thread_data.ckpt_file + '.safetensors'):
thread_data.ckpt_file += '.safetensors'
elif not os.path.exists(thread_data.ckpt_file):
raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors')
if not thread_data.precision:
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
@ -107,7 +113,7 @@ def load_model_ckpt():
if thread_data.device == 'cpu':
thread_data.precision = 'full'
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision)
if thread_data.test_sd2:
load_model_ckpt_sd2()
@ -115,7 +121,7 @@ def load_model_ckpt():
load_model_ckpt_sd1()
def load_model_ckpt_sd1():
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
sd = load_model_from_config(thread_data.ckpt_file)
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
@ -202,7 +208,7 @@ def load_model_ckpt_sd1():
thread_data.model_fs_is_half = False
print(f'''loaded model
model file: {thread_data.ckpt_file}.ckpt
model file: {thread_data.ckpt_file}
model.device: {model.device}
modelCS.device: {modelCS.cond_stage_model.device}
modelFS.device: {thread_data.modelFS.device}
@ -213,7 +219,7 @@ def load_model_ckpt_sd2():
config = OmegaConf.load(config_file)
verbose = False
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
sd = load_model_from_config(thread_data.ckpt_file)
thread_data.model = instantiate_from_config(config.model)
m, u = thread_data.model.load_state_dict(sd, strict=False)
@ -239,7 +245,7 @@ def load_model_ckpt_sd2():
thread_data.model_fs_is_half = False
print(f'''loaded model
model file: {thread_data.ckpt_file}.ckpt
model file: {thread_data.ckpt_file}
using precision: {thread_data.precision}''')
def unload_filters():
@ -401,7 +407,12 @@ def is_model_reload_necessary(req: Request):
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
if os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
req.use_stable_diffusion_model += '.ckpt'
elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'):
req.use_stable_diffusion_model += '.safetensors'
elif not os.path.exists(req.use_stable_diffusion_model):
raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
@ -869,13 +880,21 @@ def chunk(it, size):
def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if ckpt.endswith(".safetensors"):
print("Loading from safetensors")
pl_sd = load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
return sd
# utils
if "state_dict" in pl_sd:
return pl_sd["state_dict"]
else:
return pl_sd
class UserInitiatedStop(Exception):
pass

View File

@ -24,7 +24,7 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt']
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder