* Move cond_stage_model to the right device

* Removed unused vars.

* Added 'dpm2'
This commit is contained in:
Marc-Andre Ferland 2022-12-02 02:28:00 -05:00 committed by GitHub
parent ee19eaae62
commit 367e7f7065
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -753,7 +753,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
if not thread_data.test_sd2: if not thread_data.test_sd2:
move_to_cpu(thread_data.modelCS) move_to_cpu(thread_data.modelCS)
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'): if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'):
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0') raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
@ -775,6 +775,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
sampler = DDIMSampler(thread_data.model) sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
elif sampler_name == 'dpm2':
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
sampler = DPMSolverSampler(thread_data.model)
shape = [opt_C, opt_H // opt_f, opt_W // opt_f] shape = [opt_C, opt_H // opt_f, opt_W // opt_f]