mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-26 10:16:12 +01:00
Add dpm2 (#592)
* Move cond_stage_model to the right device * Removed unused vars. * Added 'dpm2'
This commit is contained in:
parent
ee19eaae62
commit
367e7f7065
@ -753,7 +753,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
if not thread_data.test_sd2:
|
||||
move_to_cpu(thread_data.modelCS)
|
||||
|
||||
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
|
||||
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'):
|
||||
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
|
||||
|
||||
|
||||
@ -775,6 +775,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
sampler = DDIMSampler(thread_data.model)
|
||||
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
elif sampler_name == 'dpm2':
|
||||
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
|
||||
sampler = DPMSolverSampler(thread_data.model)
|
||||
|
||||
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user