easydiffusion/ui/sd_internal/ddim_callback.patch

diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index b967b55..75ddd8b 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
from ldm.modules.diffusionmodules.util import make_beta_schedule
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff
+from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff
def disabled_train(self):
"""Overwrite model.train with this function to make sure train/eval mode
@@ -485,6 +485,7 @@ class UNet(DDPM):
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
+ streaming_callbacks = False,
):
@@ -523,12 +524,15 @@ class UNet(DDPM):
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
+ streaming_callbacks=streaming_callbacks
)
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
- mask = mask,init_latent=x_T,use_original_steps=False)
+ mask = mask,init_latent=x_T,use_original_steps=False,
+ callback=callback, img_callback=img_callback,
+ streaming_callbacks=streaming_callbacks)
elif sampler == "euler":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
@@ -555,11 +559,15 @@ class UNet(DDPM):
samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
unconditional_guidance_scale=unconditional_guidance_scale)
+ if streaming_callbacks: # this line needs to be right after the sampling() call
+ yield from samples
+
if(self.turbo):
self.model1.to("cpu")
self.model2.to("cpu")
- return samples
+ if not streaming_callbacks:
+ return samples
@torch.no_grad()
def plms_sampling(self, cond,b, img,
@@ -567,7 +575,8 @@ class UNet(DDPM):
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ streaming_callbacks=False):
device = self.betas.device
timesteps = self.ddim_timesteps
@@ -599,10 +608,21 @@ class UNet(DDPM):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
-
- return img
+ if callback:
+ if streaming_callbacks:
+ yield from callback(i)
+ else:
+ callback(i)
+ if img_callback:
+ if streaming_callbacks:
+ yield from img_callback(pred_x0, i)
+ else:
+ img_callback(pred_x0, i)
+
+ if streaming_callbacks and img_callback:
+ yield from img_callback(img, len(iterator)-1)
+ else:
+ return img
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -706,7 +726,9 @@ class UNet(DDPM):
@torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- mask = None,init_latent=None,use_original_steps=False):
+ mask = None,init_latent=None,use_original_steps=False,
+ callback=None, img_callback=None,
+ streaming_callbacks=False):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
@@ -730,10 +752,24 @@ class UNet(DDPM):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
+ if callback:
+ if streaming_callbacks:
+ yield from callback(i)
+ else:
+ callback(i)
+ if img_callback:
+ if streaming_callbacks:
+ yield from img_callback(x_dec, i)
+ else:
+ img_callback(x_dec, i)
+
if mask is not None:
- return x0 * mask + (1. - mask) * x_dec
+ x_dec = x0 * mask + (1. - mask) * x_dec
- return x_dec
+ if streaming_callbacks and img_callback:
+ yield from img_callback(x_dec, len(iterator)-1)
+ else:
+ return x_dec
@torch.no_grad()
diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py
index abc3098..7a32ffe 100644
--- a/optimizedSD/openaimodelSplit.py
+++ b/optimizedSD/openaimodelSplit.py
@@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import (
normalization,
timestep_embedding,
)
-from splitAttention import SpatialTransformer
+from .splitAttention import SpatialTransformer
class AttentionPool2d(nn.Module):
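
Note: once this patch is applied, UNet.sample() (and the plms_sampling/ddim_sampling it dispatches to) contain yield statements, so calling sample() with streaming_callbacks=True produces a generator that re-yields whatever the supplied callbacks yield at each denoising step. A minimal consumption sketch follows, assuming a patched optimizedSD model object is available as `model`; the `stream_samples` helper and the printed progress line are illustrative only and are not part of this patch:

def img_callback(latent, step):
    # The patched samplers call `yield from img_callback(...)`, so the callback
    # must itself be a generator; here it simply yields the step index and the
    # current latent tensor.
    yield step, latent

def stream_samples(model, **sample_kwargs):
    # With streaming_callbacks=True the patched sample() yields one item per
    # step (plus a final yield of the last image when img_callback is set).
    for step, latent in model.sample(img_callback=img_callback,
                                     streaming_callbacks=True,
                                     **sample_kwargs):
        print(f"step {step}: latent shape {tuple(latent.shape)}")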