mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-19 08:17:49 +02:00
Fix for RuntimeError, missing lines. (#591)
* Move cond_stage_model to the right device * Removed unused vars.
This commit is contained in:
parent
8eb3a3536b
commit
ee19eaae62
@ -71,8 +71,6 @@ def thread_init(device):
|
|||||||
thread_data.device_name = None
|
thread_data.device_name = None
|
||||||
thread_data.unet_bs = 1
|
thread_data.unet_bs = 1
|
||||||
thread_data.precision = 'autocast'
|
thread_data.precision = 'autocast'
|
||||||
thread_data.sampler_plms = None
|
|
||||||
thread_data.sampler_ddim = None
|
|
||||||
|
|
||||||
thread_data.turbo = False
|
thread_data.turbo = False
|
||||||
thread_data.force_full_precision = False
|
thread_data.force_full_precision = False
|
||||||
@ -230,6 +228,8 @@ def load_model_ckpt_sd2():
|
|||||||
thread_data.model.eval()
|
thread_data.model.eval()
|
||||||
del sd
|
del sd
|
||||||
|
|
||||||
|
thread_data.model.cond_stage_model.device = torch.device(thread_data.device)
|
||||||
|
|
||||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||||
thread_data.model.half()
|
thread_data.model.half()
|
||||||
thread_data.model_is_half = True
|
thread_data.model_is_half = True
|
||||||
@ -768,18 +768,15 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
|||||||
# x_T=start_code)
|
# x_T=start_code)
|
||||||
|
|
||||||
if thread_data.test_sd2:
|
if thread_data.test_sd2:
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
|
|
||||||
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
|
||||||
|
|
||||||
if sampler_name == 'plms':
|
if sampler_name == 'plms':
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
sampler = PLMSSampler(thread_data.model)
|
sampler = PLMSSampler(thread_data.model)
|
||||||
elif sampler_name == 'ddim':
|
elif sampler_name == 'ddim':
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
sampler = DDIMSampler(thread_data.model)
|
sampler = DDIMSampler(thread_data.model)
|
||||||
|
|
||||||
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||||
|
|
||||||
samples_ddim, intermediates = sampler.sample(
|
samples_ddim, intermediates = sampler.sample(
|
||||||
S=opt_ddim_steps,
|
S=opt_ddim_steps,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user