diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 892135be..17466d77 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -746,8 +746,6 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'): raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0') - if sampler_name == 'ddim': - thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) # samples, _ = sampler.sample(S=opt.steps, # conditioning=c, @@ -770,6 +768,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, elif sampler_name == 'ddim': sampler = DDIMSampler(thread_data.model) + sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + + samples_ddim = sampler.sample( S=opt_ddim_steps, conditioning=c, @@ -786,6 +787,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, sampler = sampler_name, ) else: + if sampler_name == 'ddim': + thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + samples_ddim = thread_data.model.sample( S=opt_ddim_steps, conditioning=c,