Fix for make_schedule error in sd2

This commit is contained in:
cmdr2 2022-11-25 23:15:22 +05:30 committed by GitHub
parent b924d323d4
commit 617a8b2814
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -746,8 +746,6 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'): if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0') raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
# samples, _ = sampler.sample(S=opt.steps, # samples, _ = sampler.sample(S=opt.steps,
# conditioning=c, # conditioning=c,
@ -770,6 +768,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
elif sampler_name == 'ddim': elif sampler_name == 'ddim':
sampler = DDIMSampler(thread_data.model) sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = sampler.sample( samples_ddim = sampler.sample(
S=opt_ddim_steps, S=opt_ddim_steps,
conditioning=c, conditioning=c,
@ -786,6 +787,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
sampler = sampler_name, sampler = sampler_name,
) )
else: else:
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = thread_data.model.sample( samples_ddim = thread_data.model.sample(
S=opt_ddim_steps, S=opt_ddim_steps,
conditioning=c, conditioning=c,