mirror of https://github.com/easydiffusion/easydiffusion.git
safetensors support

* Add support for checkpoints in the safetensors format: https://github.com/huggingface/safetensors. This format is designed to be safer than pickle-based .ckpt files.
* pip install safetensors
parent f701b8dc29
commit 7861c57317
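The heart of the change is in load_model_from_config(): when the checkpoint path ends in .safetensors, the state dict is read with safetensors.torch.load_file() instead of torch.load(), so no pickle code runs while loading. A minimal standalone sketch of that pattern (the load_state_dict() wrapper is illustrative, not part of the commit; the commit inlines this logic, as shown in the diff below):

import torch
from safetensors.torch import load_file   # provided by `pip install safetensors`

def load_state_dict(ckpt_path: str) -> dict:
    # Load a Stable Diffusion state dict from either a .ckpt or a .safetensors file.
    if ckpt_path.endswith(".safetensors"):
        # safetensors files contain raw tensors only, so loading them never executes pickle code
        sd = load_file(ckpt_path, device="cpu")
    else:
        # .ckpt files are pickled torch checkpoints
        sd = torch.load(ckpt_path, map_location="cpu")
    # Lightning-style .ckpt checkpoints wrap the weights in a "state_dict" key;
    # safetensors files are already a flat {name: tensor} mapping.
    return sd["state_dict"] if "state_dict" in sd else sd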
@@ -182,6 +182,16 @@ call WHERE uvicorn > .tmp
     )
 )
 
+@>nul 2>nul call python -c "import safetensors"
+@if "%ERRORLEVEL%" NEQ "0" (
+    @echo. & echo SafeTensors not found. Installing
+    @call pip install safetensors || (
+        echo "Error installing the safetensors package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
+        pause
+        exit /b
+    )
+)
+
 @>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
 @if "%ERRORLEVEL%" NEQ "0" (
     @echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt
@@ -150,6 +150,13 @@ else
     pip install picklescan || fail "Picklescan installation failed."
 fi
 
+if python -c "import safetensors" --help >/dev/null 2>&1; then
+    echo "SafeTensors is already installed."
+else
+    echo "SafeTensors not found, installing."
+    pip install safetensors || fail "SafeTensors installation failed."
+fi
+
 
 
 mkdir -p "../models/stable-diffusion"
@@ -29,6 +29,7 @@ from basicsr.archs.rrdbnet_arch import RRDBNet
 from realesrgan import RealESRGANer
 
 from threading import Lock
+from safetensors.torch import load_file
 
 import uuid
 
@@ -96,7 +97,12 @@ def isSD2():
 
 def load_model_ckpt():
     if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
-    if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
+    if os.path.exists(thread_data.ckpt_file + '.ckpt'):
+        thread_data.ckpt_file += '.ckpt'
+    elif os.path.exists(thread_data.ckpt_file + '.safetensors'):
+        thread_data.ckpt_file += '.safetensors'
+    elif not os.path.exists(thread_data.ckpt_file):
+        raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors')
 
     if not thread_data.precision:
         thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
@@ -107,7 +113,7 @@ def load_model_ckpt():
     if thread_data.device == 'cpu':
         thread_data.precision = 'full'
 
-    print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
+    print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision)
 
     if thread_data.test_sd2:
         load_model_ckpt_sd2()
@@ -115,7 +121,7 @@
         load_model_ckpt_sd1()
 
 def load_model_ckpt_sd1():
-    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
+    sd = load_model_from_config(thread_data.ckpt_file)
     li, lo = [], []
     for key, value in sd.items():
         sp = key.split(".")
@@ -202,7 +208,7 @@ def load_model_ckpt_sd1():
     thread_data.model_fs_is_half = False
 
     print(f'''loaded model
-model file: {thread_data.ckpt_file}.ckpt
+model file: {thread_data.ckpt_file}
 model.device: {model.device}
 modelCS.device: {modelCS.cond_stage_model.device}
 modelFS.device: {thread_data.modelFS.device}
@@ -213,7 +219,7 @@ def load_model_ckpt_sd2():
     config = OmegaConf.load(config_file)
     verbose = False
 
-    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
+    sd = load_model_from_config(thread_data.ckpt_file)
 
     thread_data.model = instantiate_from_config(config.model)
     m, u = thread_data.model.load_state_dict(sd, strict=False)
@@ -239,7 +245,7 @@ def load_model_ckpt_sd2():
     thread_data.model_fs_is_half = False
 
     print(f'''loaded model
-model file: {thread_data.ckpt_file}.ckpt
+model file: {thread_data.ckpt_file}
 using precision: {thread_data.precision}''')
 
 def unload_filters():
@@ -401,7 +407,12 @@ def is_model_reload_necessary(req: Request):
     # custom model support:
     # the req.use_stable_diffusion_model needs to be a valid path
     # to the ckpt file (without the extension).
-    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
+    if os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
+        req.use_stable_diffusion_model += '.ckpt'
+    elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'):
+        req.use_stable_diffusion_model += '.safetensors'
+    elif not os.path.exists(req.use_stable_diffusion_model):
+        raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors')
 
     needs_model_reload = False
     if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
@@ -869,13 +880,21 @@ def chunk(it, size):
 
 def load_model_from_config(ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
+
+    if ckpt.endswith(".safetensors"):
+        print("Loading from safetensors")
+        pl_sd = load_file(ckpt, device="cpu")
+    else:
+        pl_sd = torch.load(ckpt, map_location="cpu")
+
     if "global_step" in pl_sd:
         print(f"Global Step: {pl_sd['global_step']}")
-    sd = pl_sd["state_dict"]
-    return sd
 
-# utils
+    if "state_dict" in pl_sd:
+        return pl_sd["state_dict"]
+    else:
+        return pl_sd
+
 class UserInitiatedStop(Exception):
     pass
 
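Both load_model_ckpt() and is_model_reload_necessary() above resolve a bare model path (given without an extension) by probing for .ckpt first and then falling back to .safetensors. Sketched as a standalone helper (the name resolve_model_path is hypothetical; the commit inlines this logic in both places):

import os

def resolve_model_path(base_path: str) -> str:
    # Hypothetical helper mirroring the inline resolution in the hunks above:
    # prefer <base_path>.ckpt, then <base_path>.safetensors.
    for ext in ('.ckpt', '.safetensors'):
        if os.path.exists(base_path + ext):
            return base_path + ext
    if os.path.exists(base_path):
        return base_path   # the path may already include an extension
    raise FileNotFoundError(f'Cannot find {base_path}.ckpt or .safetensors')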
@@ -24,7 +24,7 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
 CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
 UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
 
-STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt']
+STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
 VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
 
 OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
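The server-side change only extends STABLE_DIFFUSION_MODEL_EXTENSIONS; the code that scans the models folder and feeds the UI dropdown is not part of this diff. Purely as an illustration of how such a list is typically consumed (the directory layout and function name are assumptions, not taken from server.py):

import os

STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']

def list_model_names(models_dir: str) -> list:
    # Return model names (without extension) for every file whose suffix is
    # one of the recognised checkpoint extensions.
    names = []
    for entry in sorted(os.listdir(models_dir)):
        for ext in STABLE_DIFFUSION_MODEL_EXTENSIONS:
            if entry.endswith(ext):
                names.append(entry[:-len(ext)])
                break
    return names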