diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py index b967b55..1c5f351 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config from ldm.modules.diffusionmodules.util import make_beta_schedule from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff +from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff def disabled_train(self): """Overwrite model.train with this function to make sure train/eval mode @@ -526,41 +526,49 @@ class UNet(DDPM): ) elif sampler == "ddim": + # self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - mask = mask,init_latent=x_T,use_original_steps=False) + mask = mask,init_latent=x_T,use_original_steps=False, + callback=callback, img_callback=img_callback) elif sampler == "euler": self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + unconditional_guidance_scale=unconditional_guidance_scale, + img_callback=img_callback) elif sampler == "euler_a": self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + unconditional_guidance_scale=unconditional_guidance_scale, + img_callback=img_callback) elif sampler == "dpm2": samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + unconditional_guidance_scale=unconditional_guidance_scale, + img_callback=img_callback) elif sampler == "heun": samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + unconditional_guidance_scale=unconditional_guidance_scale, + img_callback=img_callback) elif sampler == "dpm2_a": samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + unconditional_guidance_scale=unconditional_guidance_scale, + img_callback=img_callback) elif sampler == "lms": samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + unconditional_guidance_scale=unconditional_guidance_scale, + img_callback=img_callback) + + yield from samples if(self.turbo): self.model1.to("cpu") self.model2.to("cpu") - return samples - @torch.no_grad() def plms_sampling(self, cond,b, img, ddim_use_original_steps=False, @@ -599,10 +607,10 @@ class UNet(DDPM): old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: yield from callback(i) + if img_callback: yield from img_callback(pred_x0, i) - return img + yield from img_callback(img, len(iterator)-1) @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, @@ -675,7 +683,9 @@ class UNet(DDPM): def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None): # fast, but does not allow for exact reconstruction # t serves as an index to gather the correct alphas + print('making schedule') self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False) + print('made schedule', self.ddim_timesteps) sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) if noise is None: @@ -706,7 +716,10 @@ class UNet(DDPM): @torch.no_grad() def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - mask = None,init_latent=None,use_original_steps=False): + mask = None,init_latent=None,use_original_steps=False, + callback=None, img_callback=None): + + print('ddim steps', self.ddim_timesteps) timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] @@ -730,10 +743,13 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) + if callback: yield from callback(i) + if img_callback: yield from img_callback(x_dec, i) + if mask is not None: - return x0 * mask + (1. - mask) * x_dec + x_dec = x0 * mask + (1. - mask) * x_dec - return x_dec + yield from img_callback(x_dec, len(iterator)-1) @torch.no_grad() @@ -779,13 +795,16 @@ class UNet(DDPM): @torch.no_grad() - def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., + img_callback=None): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args cvd = CompVisDenoiser(ac) sigmas = cvd.get_sigmas(S) x = x*sigmas[0] + print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps") + s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. @@ -807,13 +826,18 @@ class UNet(DDPM): d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + + if img_callback: yield from img_callback(x, i) + dt = sigmas[i + 1] - sigma_hat # Euler method x = x + d * dt - return x + + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() - def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None): + def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, + img_callback=None): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args @@ -822,6 +846,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] + print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps") + s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): @@ -837,17 +863,22 @@ class UNet(DDPM): sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + if img_callback: yield from img_callback(x, i) + d = to_d(x, sigmas[i], denoised) # Euler method dt = sigma_down - sigmas[i] x = x + d * dt x = x + torch.randn_like(x) * sigma_up - return x + + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() - def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., + img_callback=None): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args @@ -855,6 +886,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] + print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps") + s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): @@ -876,6 +909,9 @@ class UNet(DDPM): d = to_d(x, sigma_hat, denoised) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + + if img_callback: yield from img_callback(x, i) + dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method @@ -895,11 +931,13 @@ class UNet(DDPM): d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_prime = (d + d_2) / 2 x = x + d_prime * dt - return x + + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() - def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., + img_callback=None): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args @@ -907,6 +945,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] + print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps") + s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. @@ -924,7 +964,7 @@ class UNet(DDPM): e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - + if img_callback: yield from img_callback(x, i) d = to_d(x, sigma_hat, denoised) # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule @@ -945,11 +985,13 @@ class UNet(DDPM): d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 - return x + + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() - def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None): + def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, + img_callback=None): """Ancestral sampling with DPM-Solver inspired second-order steps.""" extra_args = {} if extra_args is None else extra_args @@ -957,6 +999,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] + print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps") + s_in = x.new_ones([x.shape[0]]).half() for i in trange(len(sigmas) - 1, disable=disable): @@ -973,6 +1017,9 @@ class UNet(DDPM): sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + if img_callback: yield from img_callback(x, i) + d = to_d(x, sigmas[i], denoised) # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 @@ -993,11 +1040,13 @@ class UNet(DDPM): d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 x = x + torch.randn_like(x) * sigma_up - return x + + yield from img_callback(x, len(sigmas)-1) @torch.no_grad() - def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4): + def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4, + img_callback=None): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1005,6 +1054,8 @@ class UNet(DDPM): sigmas = cvd.get_sigmas(S) x = x*sigmas[0] + print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps") + ds = [] for i in trange(len(sigmas) - 1, disable=disable): @@ -1017,6 +1068,7 @@ class UNet(DDPM): e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + if img_callback: yield from img_callback(x, i) d = to_d(x, sigmas[i], denoised) ds.append(d) @@ -1027,4 +1079,5 @@ class UNet(DDPM): cur_order = min(i + 1, order) coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) - return x + + yield from img_callback(x, len(sigmas)-1) diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py index abc3098..7a32ffe 100644 --- a/optimizedSD/openaimodelSplit.py +++ b/optimizedSD/openaimodelSplit.py @@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import ( normalization, timestep_embedding, ) -from splitAttention import SpatialTransformer +from .splitAttention import SpatialTransformer class AttentionPool2d(nn.Module):