forked from extern/easydiffusion
35 lines
1.8 KiB
Diff
35 lines
1.8 KiB
Diff
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
|
|
index dcf7901..1f99adc 100644
|
|
--- a/optimizedSD/ddpm.py
|
|
+++ b/optimizedSD/ddpm.py
|
|
@@ -528,7 +528,8 @@ class UNet(DDPM):
|
|
elif sampler == "ddim":
|
|
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
|
unconditional_conditioning=unconditional_conditioning,
|
|
- mask = mask,init_latent=x_T,use_original_steps=False)
|
|
+ mask = mask,init_latent=x_T,use_original_steps=False,
|
|
+ callback=callback, img_callback=img_callback)
|
|
|
|
# elif sampler == "euler":
|
|
# cvd = CompVisDenoiser(self.alphas_cumprod)
|
|
@@ -687,7 +688,8 @@ class UNet(DDPM):
|
|
|
|
@torch.no_grad()
|
|
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
|
- mask = None,init_latent=None,use_original_steps=False):
|
|
+ mask = None,init_latent=None,use_original_steps=False,
|
|
+ callback=None, img_callback=None):
|
|
|
|
timesteps = self.ddim_timesteps
|
|
timesteps = timesteps[:t_start]
|
|
@@ -710,6 +712,9 @@ class UNet(DDPM):
|
|
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
unconditional_conditioning=unconditional_conditioning)
|
|
+
|
|
+ if callback: callback(i)
|
|
+ if img_callback: img_callback(x_dec, i)
|
|
|
|
if mask is not None:
|
|
return x0 * mask + (1. - mask) * x_dec
|