diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py index 79058bc..a473411 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py @@ -564,12 +564,12 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, callback=callback, img_callback=img_callback) + yield from samples + if(self.turbo): self.model1.to("cpu") self.model2.to("cpu") - return samples - @torch.no_grad() def plms_sampling(self, cond,b, img, ddim_use_original_steps=False, @@ -608,10 +608,10 @@ class UNet(DDPM): old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(pred_x0, i) - return img + yield from img_callback(img, len(iterator)-1) @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, @@ -740,13 +740,13 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - if img_callback: img_callback(x_dec, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x_dec, i) if mask is not None: - return x0 * mask + (1. - mask) * x_dec + x_dec = x0 * mask + (1. - mask) * x_dec - return x_dec + yield from img_callback(x_dec, len(iterator)-1) @torch.no_grad() @@ -820,12 +820,12 @@ class UNet(DDPM): d = to_d(x, sigma_hat, denoised) - if callback: callback(i) - if img_callback: img_callback(x, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x, i) dt = sigmas[i + 1] - sigma_hat # Euler method x = x + d * dt - return x + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, img_callback=None): @@ -852,14 +852,14 @@ class UNet(DDPM): denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) - if callback: callback(i) - if img_callback: img_callback(x, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x, i) d = to_d(x, sigmas[i], denoised) # Euler method dt = sigma_down - sigmas[i] x = x + d * dt x = x + torch.randn_like(x) * sigma_up - return x + yield from img_callback(x, len(sigmas)-1) @@ -892,8 +892,8 @@ class UNet(DDPM): denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) d = to_d(x, sigma_hat, denoised) - if callback: callback(i) - if img_callback: img_callback(x, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x, i) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method @@ -913,7 +913,7 @@ class UNet(DDPM): d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 x = x + d_prime * dt - return x + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() @@ -944,8 +944,8 @@ class UNet(DDPM): e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - if callback: callback(i) - if img_callback: img_callback(x, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x, i) d = to_d(x, sigma_hat, denoised) # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule @@ -966,7 +966,7 @@ class UNet(DDPM): d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 - return x + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() @@ -994,8 +994,8 @@ class UNet(DDPM): sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) - if callback: callback(i) - if img_callback: img_callback(x, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x, i) d = to_d(x, sigmas[i], denoised) # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 @@ -1016,7 +1016,7 @@ class UNet(DDPM): d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 x = x + torch.randn_like(x) * sigma_up - return x + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() @@ -1042,8 +1042,8 @@ class UNet(DDPM): e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - if callback: callback(i) - if img_callback: img_callback(x, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x, i) d = to_d(x, sigmas[i], denoised) ds.append(d) @@ -1054,4 +1054,4 @@ class UNet(DDPM): cur_order = min(i + 1, order) coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) - return x + yield from img_callback(x, len(sigmas)-1)