From ee19eaae6245cf868a454012b2ab0fa2b059e4c6 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 2 Dec 2022 02:27:26 -0500 Subject: [PATCH] Fix for RuntimeError, missing lines. (#591) * Move cond_stage_model to the right device * Removed unused vars. --- ui/sd_internal/runtime.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 26c116ad..934d8758 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -71,8 +71,6 @@ def thread_init(device): thread_data.device_name = None thread_data.unet_bs = 1 thread_data.precision = 'autocast' - thread_data.sampler_plms = None - thread_data.sampler_ddim = None thread_data.turbo = False thread_data.force_full_precision = False @@ -230,6 +228,8 @@ def load_model_ckpt_sd2(): thread_data.model.eval() del sd + thread_data.model.cond_stage_model.device = torch.device(thread_data.device) + if thread_data.device != "cpu" and thread_data.precision == "autocast": thread_data.model.half() thread_data.model_is_half = True @@ -768,18 +768,15 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, # x_T=start_code) if thread_data.test_sd2: - from ldm.models.diffusion.ddim import DDIMSampler - from ldm.models.diffusion.plms import PLMSSampler - - shape = [opt_C, opt_H // opt_f, opt_W // opt_f] - if sampler_name == 'plms': + from ldm.models.diffusion.plms import PLMSSampler sampler = PLMSSampler(thread_data.model) elif sampler_name == 'ddim': + from ldm.models.diffusion.ddim import DDIMSampler sampler = DDIMSampler(thread_data.model) - sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + shape = [opt_C, opt_H // opt_f, opt_W // opt_f] samples_ddim, intermediates = sampler.sample( S=opt_ddim_steps,