diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf7901..4028a70 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -485,6 +485,7 @@ class UNet(DDPM):
                log_every_t=100,
                unconditional_guidance_scale=1.,
                unconditional_conditioning=None,
+               streaming_callbacks = False,
                ):
@@ -523,12 +524,15 @@ class UNet(DDPM):
                                          log_every_t=log_every_t,
                                          unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning,
+                                         streaming_callbacks=streaming_callbacks
                                          )

         elif sampler == "ddim":
             samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                          unconditional_conditioning=unconditional_conditioning,
-                                         mask = mask,init_latent=x_T,use_original_steps=False)
+                                         mask = mask,init_latent=x_T,use_original_steps=False,
+                                         callback=callback, img_callback=img_callback,
+                                         streaming_callbacks=streaming_callbacks)

         # elif sampler == "euler":
         #     cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -536,11 +540,15 @@ class UNet(DDPM):
         #     samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning,
         #                                  unconditional_guidance_scale=unconditional_guidance_scale)

+        if streaming_callbacks: # this line needs to be right after the sampling() call
+            yield from samples
+
         if(self.turbo):
             self.model1.to("cpu")
             self.model2.to("cpu")

-        return samples
+        if not streaming_callbacks:
+            return samples

     @torch.no_grad()
     def plms_sampling(self, cond,b, img,
@@ -548,7 +556,8 @@ class UNet(DDPM):
                       callback=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      streaming_callbacks=False):

         device = self.betas.device
         timesteps = self.ddim_timesteps
@@ -580,10 +589,22 @@ class UNet(DDPM):
             old_eps.append(e_t)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
+            if callback:
+                if streaming_callbacks:
+                    yield from callback(i)
+                else:
+                    callback(i)
+            if img_callback:
+                if streaming_callbacks:
+                    yield from img_callback(pred_x0, i)
+                else:
+                    img_callback(pred_x0, i)

-        return img
+
+        if streaming_callbacks and img_callback:
+            yield from img_callback(img, len(iterator)-1)
+        else:
+            return img

     @torch.no_grad()
     def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -687,7 +708,9 @@ class UNet(DDPM):

     @torch.no_grad()
     def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-                      mask = None,init_latent=None,use_original_steps=False):
+                      mask = None,init_latent=None,use_original_steps=False,
+                      callback=None, img_callback=None,
+                      streaming_callbacks=False):

         timesteps = self.ddim_timesteps
         timesteps = timesteps[:t_start]
@@ -710,11 +733,25 @@ class UNet(DDPM):
             x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=unconditional_conditioning)
+
+            if callback:
+                if streaming_callbacks:
+                    yield from callback(i)
+                else:
+                    callback(i)
+            if img_callback:
+                if streaming_callbacks:
+                    yield from img_callback(x_dec, i)
+                else:
+                    img_callback(x_dec, i)

         if mask is not None:
-            return x0 * mask + (1. - mask) * x_dec
+            x_dec = x0 * mask + (1. - mask) * x_dec

-        return x_dec
+        if streaming_callbacks and img_callback:
+            yield from img_callback(x_dec, len(iterator)-1)
+        else:
+            return x_dec

     @torch.no_grad()
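
Taken together, the patch threads a new streaming_callbacks flag from sample() down into plms_sampling() and ddim_sampling(). When the flag is set, callback and img_callback are treated as generators and everything they yield is forwarded upward with "yield from", so intermediate latents can be streamed to the caller while sampling runs; the final image is then routed through img_callback as well, since a generator's return value is not visible to a plain for loop. Below is a minimal, self-contained sketch of that control flow; the loop, the names sampling_loop/show_progress, and the integer "latent" are illustrative stand-ins, not the real UNet API.

# Illustrative sketch only: sampling_loop stands in for plms_sampling/ddim_sampling,
# and the integer "latent" stands in for the tensors used in optimizedSD/ddpm.py.
def sampling_loop(steps, img_callback=None, streaming_callbacks=False):
    img = 0  # placeholder for the latent being denoised
    for i in range(steps):
        img += 1  # placeholder for one denoising step
        if img_callback:
            if streaming_callbacks:
                # The callback is expected to be a generator; whatever it
                # yields bubbles up to whoever iterates over sampling_loop().
                yield from img_callback(img, i)
            else:
                img_callback(img, i)

    if streaming_callbacks and img_callback:
        # The final image is also routed through the callback; in streaming
        # mode the caller only sees what is yielded, not a return value.
        yield from img_callback(img, steps - 1)
    else:
        return img

def show_progress(img, i):
    # Generator-style callback: each yielded item reaches the caller.
    yield f"step {i}: latent={img}"

# Streaming consumption: iterate over the sampler itself.
for update in sampling_loop(3, img_callback=show_progress, streaming_callbacks=True):
    print(update)

The design keeps everything synchronous, with no threads or queues: each intermediate image reaches the caller the moment it is produced, simply because the sampler is consumed as a generator.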