forked from extern/easydiffusion
105 lines
5.0 KiB
Diff
105 lines
5.0 KiB
Diff
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
|
|
index 27ead0e..e0ff53f 100644
|
|
--- a/ldm/models/diffusion/ddim.py
|
|
+++ b/ldm/models/diffusion/ddim.py
|
|
@@ -100,7 +100,7 @@ class DDIMSampler(object):
|
|
size = (batch_size, C, H, W)
|
|
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
|
|
|
- samples, intermediates = self.ddim_sampling(conditioning, size,
|
|
+ samples = self.ddim_sampling(conditioning, size,
|
|
callback=callback,
|
|
img_callback=img_callback,
|
|
quantize_denoised=quantize_x0,
|
|
@@ -117,7 +117,8 @@ class DDIMSampler(object):
|
|
dynamic_threshold=dynamic_threshold,
|
|
ucg_schedule=ucg_schedule
|
|
)
|
|
- return samples, intermediates
|
|
+ # return samples, intermediates
|
|
+ yield from samples
|
|
|
|
@torch.no_grad()
|
|
def ddim_sampling(self, cond, shape,
|
|
@@ -168,14 +169,15 @@ class DDIMSampler(object):
|
|
unconditional_conditioning=unconditional_conditioning,
|
|
dynamic_threshold=dynamic_threshold)
|
|
img, pred_x0 = outs
|
|
- if callback: callback(i)
|
|
- if img_callback: img_callback(pred_x0, i)
|
|
+ if callback: yield from callback(i)
|
|
+ if img_callback: yield from img_callback(pred_x0, i)
|
|
|
|
if index % log_every_t == 0 or index == total_steps - 1:
|
|
intermediates['x_inter'].append(img)
|
|
intermediates['pred_x0'].append(pred_x0)
|
|
|
|
- return img, intermediates
|
|
+ # return img, intermediates
|
|
+ yield from img_callback(pred_x0, len(iterator)-1)
|
|
|
|
@torch.no_grad()
|
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
|
@@ -315,7 +317,7 @@ class DDIMSampler(object):
|
|
|
|
@torch.no_grad()
|
|
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
|
- use_original_steps=False, callback=None):
|
|
+ use_original_steps=False, callback=None, img_callback=None):
|
|
|
|
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
|
timesteps = timesteps[:t_start]
|
|
@@ -332,5 +334,6 @@ class DDIMSampler(object):
|
|
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
|
unconditional_conditioning=unconditional_conditioning)
|
|
- if callback: callback(i)
|
|
- return x_dec
|
|
\ No newline at end of file
|
|
+ if callback: yield from callback(i)
|
|
+ if img_callback: yield from img_callback(x_dec, i)
|
|
+ yield from img_callback(x_dec, len(iterator)-1)
|
|
\ No newline at end of file
|
|
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
|
|
index 7002a36..0951f39 100644
|
|
--- a/ldm/models/diffusion/plms.py
|
|
+++ b/ldm/models/diffusion/plms.py
|
|
@@ -96,7 +96,7 @@ class PLMSSampler(object):
|
|
size = (batch_size, C, H, W)
|
|
print(f'Data shape for PLMS sampling is {size}')
|
|
|
|
- samples, intermediates = self.plms_sampling(conditioning, size,
|
|
+ samples = self.plms_sampling(conditioning, size,
|
|
callback=callback,
|
|
img_callback=img_callback,
|
|
quantize_denoised=quantize_x0,
|
|
@@ -112,7 +112,8 @@ class PLMSSampler(object):
|
|
unconditional_conditioning=unconditional_conditioning,
|
|
dynamic_threshold=dynamic_threshold,
|
|
)
|
|
- return samples, intermediates
|
|
+ #return samples, intermediates
|
|
+ yield from samples
|
|
|
|
@torch.no_grad()
|
|
def plms_sampling(self, cond, shape,
|
|
@@ -165,14 +166,15 @@ class PLMSSampler(object):
|
|
old_eps.append(e_t)
|
|
if len(old_eps) >= 4:
|
|
old_eps.pop(0)
|
|
- if callback: callback(i)
|
|
- if img_callback: img_callback(pred_x0, i)
|
|
+ if callback: yield from callback(i)
|
|
+ if img_callback: yield from img_callback(pred_x0, i)
|
|
|
|
if index % log_every_t == 0 or index == total_steps - 1:
|
|
intermediates['x_inter'].append(img)
|
|
intermediates['pred_x0'].append(pred_x0)
|
|
|
|
- return img, intermediates
|
|
+ # return img, intermediates
|
|
+ yield from img_callback(pred_x0, len(iterator)-1)
|
|
|
|
@torch.no_grad()
|
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|