mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-26 18:25:29 +01:00
Fix for RuntimeError, missing lines. (#591)
* Move cond_stage_model to the right device * Removed unused vars.
This commit is contained in:
parent
8eb3a3536b
commit
ee19eaae62
@ -71,8 +71,6 @@ def thread_init(device):
|
||||
thread_data.device_name = None
|
||||
thread_data.unet_bs = 1
|
||||
thread_data.precision = 'autocast'
|
||||
thread_data.sampler_plms = None
|
||||
thread_data.sampler_ddim = None
|
||||
|
||||
thread_data.turbo = False
|
||||
thread_data.force_full_precision = False
|
||||
@ -230,6 +228,8 @@ def load_model_ckpt_sd2():
|
||||
thread_data.model.eval()
|
||||
del sd
|
||||
|
||||
thread_data.model.cond_stage_model.device = torch.device(thread_data.device)
|
||||
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
thread_data.model.half()
|
||||
thread_data.model_is_half = True
|
||||
@ -768,18 +768,15 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
# x_T=start_code)
|
||||
|
||||
if thread_data.test_sd2:
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
|
||||
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if sampler_name == 'plms':
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
sampler = PLMSSampler(thread_data.model)
|
||||
elif sampler_name == 'ddim':
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
sampler = DDIMSampler(thread_data.model)
|
||||
|
||||
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
|
||||
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
samples_ddim, intermediates = sampler.sample(
|
||||
S=opt_ddim_steps,
|
||||
|
Loading…
Reference in New Issue
Block a user