diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py index b967b55..75ddd8b 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config from ldm.modules.diffusionmodules.util import make_beta_schedule from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff +from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff def disabled_train(self): """Overwrite model.train with this function to make sure train/eval mode @@ -485,6 +485,7 @@ class UNet(DDPM): log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, + streaming_callbacks = False, ): @@ -523,12 +524,15 @@ class UNet(DDPM): log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, + streaming_callbacks=streaming_callbacks ) elif sampler == "ddim": samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - mask = mask,init_latent=x_T,use_original_steps=False) + mask = mask,init_latent=x_T,use_original_steps=False, + callback=callback, img_callback=img_callback, + streaming_callbacks=streaming_callbacks) elif sampler == "euler": self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) @@ -555,11 +559,15 @@ class UNet(DDPM): samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, unconditional_guidance_scale=unconditional_guidance_scale) + if streaming_callbacks: # this line needs to be right after the sampling() call + yield from samples + if(self.turbo): self.model1.to("cpu") self.model2.to("cpu") - return samples + if not streaming_callbacks: + return samples @torch.no_grad() def plms_sampling(self, cond,b, img, @@ -567,7 +575,8 @@ class UNet(DDPM): callback=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + unconditional_guidance_scale=1., unconditional_conditioning=None, + streaming_callbacks=False): device = self.betas.device timesteps = self.ddim_timesteps @@ -599,10 +608,21 @@ class UNet(DDPM): old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - return img + if callback: + if streaming_callbacks: + yield from callback(i) + else: + callback(i) + if img_callback: + if streaming_callbacks: + yield from img_callback(pred_x0, i) + else: + img_callback(pred_x0, i) + + if streaming_callbacks and img_callback: + yield from img_callback(img, len(iterator)-1) + else: + return img @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, @@ -706,7 +726,9 @@ class UNet(DDPM): @torch.no_grad() def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - mask = None,init_latent=None,use_original_steps=False): + mask = None,init_latent=None,use_original_steps=False, + callback=None, img_callback=None, + streaming_callbacks=False): timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] @@ -730,10 +752,24 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) + if callback: + if streaming_callbacks: + yield from callback(i) + else: + callback(i) + if img_callback: + if streaming_callbacks: + yield from img_callback(x_dec, i) + else: + img_callback(x_dec, i) + if mask is not None: - return x0 * mask + (1. - mask) * x_dec + x_dec = x0 * mask + (1. - mask) * x_dec - return x_dec + if streaming_callbacks and img_callback: + yield from img_callback(x_dec, len(iterator)-1) + else: + return x_dec @torch.no_grad() diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py index abc3098..7a32ffe 100644 --- a/optimizedSD/openaimodelSplit.py +++ b/optimizedSD/openaimodelSplit.py @@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import ( normalization, timestep_embedding, ) -from splitAttention import SpatialTransformer +from .splitAttention import SpatialTransformer class AttentionPool2d(nn.Module):