From 83cb473a45de3e1e55a47a2ca2623437177c450f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 23 Sep 2022 11:14:06 +0530 Subject: [PATCH] Fix the ddim_timesteps attribute missing error for txt2img with the ddim sampler --- ui/sd_internal/ddim_callback.patch | 57 +++++++++++------------------- ui/sd_internal/runtime.py | 3 ++ 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch index d65399a7..9852aab8 100644 --- a/ui/sd_internal/ddim_callback.patch +++ b/ui/sd_internal/ddim_callback.patch @@ -1,5 +1,5 @@ diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py -index b967b55..1c5f351 100644 +index b967b55..10a7c32 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config @@ -11,11 +11,8 @@ index b967b55..1c5f351 100644 def disabled_train(self): """Overwrite model.train with this function to make sure train/eval mode -@@ -526,41 +526,49 @@ class UNet(DDPM): - ) - +@@ -528,39 +528,46 @@ class UNet(DDPM): elif sampler == "ddim": -+ # self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - mask = mask,init_latent=x_T,use_original_steps=False) @@ -70,7 +67,7 @@ index b967b55..1c5f351 100644 @torch.no_grad() def plms_sampling(self, cond,b, img, ddim_use_original_steps=False, -@@ -599,10 +607,10 @@ class UNet(DDPM): +@@ -599,10 +606,10 @@ class UNet(DDPM): old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) @@ -84,29 +81,17 @@ index b967b55..1c5f351 100644 @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -@@ -675,7 +683,9 @@ class UNet(DDPM): - def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas -+ print('making schedule') - self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False) -+ print('made schedule', self.ddim_timesteps) - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - - if noise is None: -@@ -706,7 +716,10 @@ class UNet(DDPM): +@@ -706,7 +713,8 @@ class UNet(DDPM): @torch.no_grad() def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - mask = None,init_latent=None,use_original_steps=False): + mask = None,init_latent=None,use_original_steps=False, + callback=None, img_callback=None): -+ -+ print('ddim steps', self.ddim_timesteps) timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] -@@ -730,10 +743,13 @@ class UNet(DDPM): +@@ -730,10 +738,13 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) @@ -122,7 +107,7 @@ index b967b55..1c5f351 100644 @torch.no_grad() -@@ -779,13 +795,16 @@ class UNet(DDPM): +@@ -779,13 +790,16 @@ class UNet(DDPM): @torch.no_grad() @@ -140,7 +125,7 @@ index b967b55..1c5f351 100644 s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. -@@ -807,13 +826,18 @@ class UNet(DDPM): +@@ -807,13 +821,18 @@ class UNet(DDPM): d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) @@ -161,7 +146,7 @@ index b967b55..1c5f351 100644 """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args -@@ -822,6 +846,8 @@ class UNet(DDPM): +@@ -822,6 +841,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] @@ -170,7 +155,7 @@ index b967b55..1c5f351 100644 s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): -@@ -837,17 +863,22 @@ class UNet(DDPM): +@@ -837,17 +858,22 @@ class UNet(DDPM): sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) @@ -195,7 +180,7 @@ index b967b55..1c5f351 100644 """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args -@@ -855,6 +886,8 @@ class UNet(DDPM): +@@ -855,6 +881,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] @@ -204,7 +189,7 @@ index b967b55..1c5f351 100644 s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): -@@ -876,6 +909,9 @@ class UNet(DDPM): +@@ -876,6 +904,9 @@ class UNet(DDPM): d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) @@ -214,7 +199,7 @@ index b967b55..1c5f351 100644 dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method -@@ -895,11 +931,13 @@ class UNet(DDPM): +@@ -895,11 +926,13 @@ class UNet(DDPM): d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 x = x + d_prime * dt @@ -230,7 +215,7 @@ index b967b55..1c5f351 100644 """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args -@@ -907,6 +945,8 @@ class UNet(DDPM): +@@ -907,6 +940,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] @@ -239,7 +224,7 @@ index b967b55..1c5f351 100644 s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. -@@ -924,7 +964,7 @@ class UNet(DDPM): +@@ -924,7 +959,7 @@ class UNet(DDPM): e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) @@ -248,7 +233,7 @@ index b967b55..1c5f351 100644 d = to_d(x, sigma_hat, denoised) # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule -@@ -945,11 +985,13 @@ class UNet(DDPM): +@@ -945,11 +980,13 @@ class UNet(DDPM): d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 @@ -264,7 +249,7 @@ index b967b55..1c5f351 100644 """Ancestral sampling with DPM-Solver inspired second-order steps.""" extra_args = {} if extra_args is None else extra_args -@@ -957,6 +999,8 @@ class UNet(DDPM): +@@ -957,6 +994,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] @@ -273,7 +258,7 @@ index b967b55..1c5f351 100644 s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): -@@ -973,6 +1017,9 @@ class UNet(DDPM): +@@ -973,6 +1012,9 @@ class UNet(DDPM): sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) @@ -283,7 +268,7 @@ index b967b55..1c5f351 100644 d = to_d(x, sigmas[i], denoised) # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 -@@ -993,11 +1040,13 @@ class UNet(DDPM): +@@ -993,11 +1035,13 @@ class UNet(DDPM): d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 x = x + torch.randn_like(x) * sigma_up @@ -299,7 +284,7 @@ index b967b55..1c5f351 100644 extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) -@@ -1005,6 +1054,8 @@ class UNet(DDPM): +@@ -1005,6 +1049,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] @@ -308,7 +293,7 @@ index b967b55..1c5f351 100644 ds = [] for i in trange(len(sigmas) - 1, disable=disable): -@@ -1017,6 +1068,7 @@ class UNet(DDPM): +@@ -1017,6 +1063,7 @@ class UNet(DDPM): e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) @@ -316,7 +301,7 @@ index b967b55..1c5f351 100644 d = to_d(x, sigmas[i], denoised) ds.append(d) -@@ -1027,4 +1079,5 @@ class UNet(DDPM): +@@ -1027,4 +1074,5 @@ class UNet(DDPM): cur_order = min(i + 1, order) coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 9fa88e5b..c0382f82 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -522,6 +522,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, while torch.cuda.memory_allocated() / 1e6 >= mem: time.sleep(1) + if sampler_name == 'ddim' and not hasattr(model, 'ddim_timesteps'): + model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) + samples_ddim = model.sample( S=opt_ddim_steps, conditioning=c,