mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-25 04:12:01 +02:00
SD2 models no longer need to be prefixed with 'sd2_' . The model loader now checks for a key that only SD2 models seem to have, to deduce which config file to use
This commit is contained in:
parent
273525e6f9
commit
bfdf487d52
@ -128,7 +128,7 @@ def load_model_ckpt():
|
||||
load_model_ckpt_sd1()
|
||||
|
||||
def load_model_ckpt_sd1():
|
||||
sd = load_model_from_config(thread_data.ckpt_file)
|
||||
sd, model_ver = load_model_from_config(thread_data.ckpt_file)
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split(".")
|
||||
@ -222,12 +222,12 @@ def load_model_ckpt_sd1():
|
||||
using precision: {thread_data.precision}''')
|
||||
|
||||
def load_model_ckpt_sd2():
|
||||
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml"
|
||||
sd, model_ver = load_model_from_config(thread_data.ckpt_file)
|
||||
|
||||
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml"
|
||||
config = OmegaConf.load(config_file)
|
||||
verbose = False
|
||||
|
||||
sd = load_model_from_config(thread_data.ckpt_file)
|
||||
|
||||
thread_data.model = instantiate_from_config(config.model)
|
||||
m, u = thread_data.model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
@ -941,6 +941,7 @@ def chunk(it, size):
|
||||
|
||||
def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
model_ver = 'sd1'
|
||||
|
||||
if ckpt.endswith(".safetensors"):
|
||||
print("Loading from safetensors")
|
||||
@ -952,9 +953,13 @@ def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
|
||||
if "state_dict" in pl_sd:
|
||||
return pl_sd["state_dict"]
|
||||
# check for a key that only seems to be present in SD2 models
|
||||
if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys():
|
||||
model_ver = 'sd2'
|
||||
|
||||
return pl_sd["state_dict"], model_ver
|
||||
else:
|
||||
return pl_sd
|
||||
return pl_sd, model_ver
|
||||
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user