From 367e7f7065acb10655a5787179344efcd73ae06f Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 2 Dec 2022 02:28:00 -0500 Subject: [PATCH] Add dpm2 (#592) * Move cond_stage_model to the right device * Removed unused vars. * Added 'dpm2' --- ui/sd_internal/runtime.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 934d8758..8500186a 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -753,7 +753,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, if not thread_data.test_sd2: move_to_cpu(thread_data.modelCS) - if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'): + if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'): raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0') @@ -775,6 +775,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, from ldm.models.diffusion.ddim import DDIMSampler sampler = DDIMSampler(thread_data.model) sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + elif sampler_name == 'dpm2': + from ldm.models.diffusion.dpm_solver import DPMSolverSampler + sampler = DPMSolverSampler(thread_data.model) shape = [opt_C, opt_H // opt_f, opt_W // opt_f]