Fix for RuntimeError, missing lines. (#591)

* Move cond_stage_model to the right device

* Removed unused vars.
This commit is contained in:
Marc-Andre Ferland 2022-12-02 02:27:26 -05:00 committed by GitHub
parent 8eb3a3536b
commit ee19eaae62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -71,8 +71,6 @@ def thread_init(device):
thread_data.device_name = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
thread_data.sampler_plms = None
thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.force_full_precision = False
@ -230,6 +228,8 @@ def load_model_ckpt_sd2():
thread_data.model.eval()
del sd
thread_data.model.cond_stage_model.device = torch.device(thread_data.device)
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.model_is_half = True
@ -768,18 +768,15 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
# x_T=start_code)
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
if sampler_name == 'plms':
from ldm.models.diffusion.plms import PLMSSampler
sampler = PLMSSampler(thread_data.model)
elif sampler_name == 'ddim':
from ldm.models.diffusion.ddim import DDIMSampler
sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
samples_ddim, intermediates = sampler.sample(
S=opt_ddim_steps,