diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,8 @@ class DDIMSampler(object):
         size = (batch_size, C, H, W)
         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
 
-        samples, intermediates = self.ddim_sampling(conditioning, size,
+        # ddim_sampling is a generator: re-yield its progress items and keep its return value.
+        samples, intermediates = yield from self.ddim_sampling(conditioning, size,
                                                     callback=callback,
                                                     img_callback=img_callback,
                                                     quantize_denoised=quantize_x0,
@@ -168,14 +169,23 @@ class DDIMSampler(object):
                                       unconditional_conditioning=unconditional_conditioning,
                                       dynamic_threshold=dynamic_threshold)
             img, pred_x0 = outs
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
+            # A callback may be a plain function (returns None) or may return an
+            # iterable whose items are streamed to whoever iterates this sampler.
+            if callback:
+                progress = callback(i)
+                if progress is not None:
+                    yield from progress
+            if img_callback:
+                progress = img_callback(pred_x0, i)
+                if progress is not None:
+                    yield from progress
 
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates['x_inter'].append(img)
                 intermediates['pred_x0'].append(pred_x0)
 
+        # Generator return value: surfaces as the result of `yield from` in sample().
         return img, intermediates
 
     @torch.no_grad()
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,8 @@ class PLMSSampler(object):
         size = (batch_size, C, H, W)
         print(f'Data shape for PLMS sampling is {size}')
 
-        samples, intermediates = self.plms_sampling(conditioning, size,
+        # plms_sampling is a generator: re-yield its progress items and keep its return value.
+        samples, intermediates = yield from self.plms_sampling(conditioning, size,
                                                     callback=callback,
                                                     img_callback=img_callback,
                                                     quantize_denoised=quantize_x0,
@@ -165,14 +166,23 @@ class PLMSSampler(object):
             old_eps.append(e_t)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
+            # A callback may be a plain function (returns None) or may return an
+            # iterable whose items are streamed to whoever iterates this sampler.
+            if callback:
+                progress = callback(i)
+                if progress is not None:
+                    yield from progress
+            if img_callback:
+                progress = img_callback(pred_x0, i)
+                if progress is not None:
+                    yield from progress
 
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates['x_inter'].append(img)
                 intermediates['pred_x0'].append(pred_x0)
 
+        # Generator return value: surfaces as the result of `yield from` in sample().
         return img, intermediates
 
     @torch.no_grad()
     def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,