diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf7901..1f99adc 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -528,7 +528,8 @@ class UNet(DDPM):
         elif sampler == "ddim":
             samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning,
-                                         mask = mask,init_latent=x_T,use_original_steps=False)
+                                         mask = mask,init_latent=x_T,use_original_steps=False,
+                                         callback=callback, img_callback=img_callback)
 
         # elif sampler == "euler":
         #     cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -687,7 +688,8 @@ class UNet(DDPM):
 
     @torch.no_grad()
     def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-               mask = None,init_latent=None,use_original_steps=False):
+               mask = None,init_latent=None,use_original_steps=False,
+               callback=None, img_callback=None):
 
         timesteps = self.ddim_timesteps
         timesteps = timesteps[:t_start]
@@ -710,6 +712,9 @@ class UNet(DDPM):
             x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                           unconditional_guidance_scale=unconditional_guidance_scale,
                                           unconditional_conditioning=unconditional_conditioning)
+
+            if callback: callback(i)
+            if img_callback: img_callback(x_dec, i)
 
         if mask is not None:
             return x0 * mask + (1. - mask) * x_dec