From bfdf487d524e42e5c8e1ed81d724b9751a7be008 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 7 Dec 2022 16:19:46 +0530 Subject: [PATCH] SD2 models no longer need to be prefixed with 'sd2_' . The model loader now checks for a key that only SD2 models seem to have, to deduce which config file to use --- ui/sd_internal/runtime.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 733603ac..30d19ef1 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -128,7 +128,7 @@ def load_model_ckpt(): load_model_ckpt_sd1() def load_model_ckpt_sd1(): - sd = load_model_from_config(thread_data.ckpt_file) + sd, model_ver = load_model_from_config(thread_data.ckpt_file) li, lo = [], [] for key, value in sd.items(): sp = key.split(".") @@ -222,12 +222,12 @@ def load_model_ckpt_sd1(): using precision: {thread_data.precision}''') def load_model_ckpt_sd2(): - config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml" + sd, model_ver = load_model_from_config(thread_data.ckpt_file) + + config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml" config = OmegaConf.load(config_file) verbose = False - sd = load_model_from_config(thread_data.ckpt_file) - thread_data.model = instantiate_from_config(config.model) m, u = thread_data.model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: @@ -941,6 +941,7 @@ def chunk(it, size): def load_model_from_config(ckpt, verbose=False): print(f"Loading model from {ckpt}") + model_ver = 'sd1' if ckpt.endswith(".safetensors"): print("Loading from safetensors") @@ -952,9 +953,13 @@ def load_model_from_config(ckpt, verbose=False): print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: - return pl_sd["state_dict"] + # check for a key that only seems to be present in SD2 models + if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys(): + model_ver = 'sd2' + + return pl_sd["state_dict"], model_ver else: - return pl_sd + return pl_sd, model_ver class UserInitiatedStop(Exception): pass