SD2 models no longer need to be prefixed with 'sd2_' . The model loader now checks for a key that only SD2 models seem to have, to deduce which config file to use

This commit is contained in:
cmdr2 2022-12-07 16:19:46 +05:30
parent 273525e6f9
commit bfdf487d52

View File

@ -128,7 +128,7 @@ def load_model_ckpt():
load_model_ckpt_sd1() load_model_ckpt_sd1()
def load_model_ckpt_sd1(): def load_model_ckpt_sd1():
sd = load_model_from_config(thread_data.ckpt_file) sd, model_ver = load_model_from_config(thread_data.ckpt_file)
li, lo = [], [] li, lo = [], []
for key, value in sd.items(): for key, value in sd.items():
sp = key.split(".") sp = key.split(".")
@ -222,12 +222,12 @@ def load_model_ckpt_sd1():
using precision: {thread_data.precision}''') using precision: {thread_data.precision}''')
def load_model_ckpt_sd2(): def load_model_ckpt_sd2():
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml" sd, model_ver = load_model_from_config(thread_data.ckpt_file)
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml"
config = OmegaConf.load(config_file) config = OmegaConf.load(config_file)
verbose = False verbose = False
sd = load_model_from_config(thread_data.ckpt_file)
thread_data.model = instantiate_from_config(config.model) thread_data.model = instantiate_from_config(config.model)
m, u = thread_data.model.load_state_dict(sd, strict=False) m, u = thread_data.model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose: if len(m) > 0 and verbose:
@ -941,6 +941,7 @@ def chunk(it, size):
def load_model_from_config(ckpt, verbose=False): def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
model_ver = 'sd1'
if ckpt.endswith(".safetensors"): if ckpt.endswith(".safetensors"):
print("Loading from safetensors") print("Loading from safetensors")
@ -952,9 +953,13 @@ def load_model_from_config(ckpt, verbose=False):
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
return pl_sd["state_dict"] # check for a key that only seems to be present in SD2 models
if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys():
model_ver = 'sd2'
return pl_sd["state_dict"], model_ver
else: else:
return pl_sd return pl_sd, model_ver
class UserInitiatedStop(Exception): class UserInitiatedStop(Exception):
pass pass