mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-12 09:18:51 +01:00
Fix blurry img2img
This commit is contained in:
parent
c675caf3f9
commit
e7f9db5e56
@ -1,5 +1,5 @@
|
|||||||
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
|
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
|
||||||
index b967b55..e06e653 100644
|
index b967b55..1c5f351 100644
|
||||||
--- a/optimizedSD/ddpm.py
|
--- a/optimizedSD/ddpm.py
|
||||||
+++ b/optimizedSD/ddpm.py
|
+++ b/optimizedSD/ddpm.py
|
||||||
@@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
|
@@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
|
||||||
@ -15,7 +15,7 @@ index b967b55..e06e653 100644
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif sampler == "ddim":
|
elif sampler == "ddim":
|
||||||
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
+ # self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
|
||||||
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
- mask = mask,init_latent=x_T,use_original_steps=False)
|
- mask = mask,init_latent=x_T,use_original_steps=False)
|
||||||
@ -84,17 +84,29 @@ index b967b55..e06e653 100644
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
@@ -706,7 +714,8 @@ class UNet(DDPM):
|
@@ -675,7 +683,9 @@ class UNet(DDPM):
|
||||||
|
def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
|
||||||
|
# fast, but does not allow for exact reconstruction
|
||||||
|
# t serves as an index to gather the correct alphas
|
||||||
|
+ print('making schedule')
|
||||||
|
self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
|
||||||
|
+ print('made schedule', self.ddim_timesteps)
|
||||||
|
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
@@ -706,7 +716,10 @@ class UNet(DDPM):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||||
- mask = None,init_latent=None,use_original_steps=False):
|
- mask = None,init_latent=None,use_original_steps=False):
|
||||||
+ mask = None,init_latent=None,use_original_steps=False,
|
+ mask = None,init_latent=None,use_original_steps=False,
|
||||||
+ callback=None, img_callback=None):
|
+ callback=None, img_callback=None):
|
||||||
|
+
|
||||||
|
+ print('ddim steps', self.ddim_timesteps)
|
||||||
|
|
||||||
timesteps = self.ddim_timesteps
|
timesteps = self.ddim_timesteps
|
||||||
timesteps = timesteps[:t_start]
|
timesteps = timesteps[:t_start]
|
||||||
@@ -730,10 +739,13 @@ class UNet(DDPM):
|
@@ -730,10 +743,13 @@ class UNet(DDPM):
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning)
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
|
|
||||||
@ -110,7 +122,7 @@ index b967b55..e06e653 100644
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -779,13 +791,16 @@ class UNet(DDPM):
|
@@ -779,13 +795,16 @@ class UNet(DDPM):
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -128,7 +140,7 @@ index b967b55..e06e653 100644
|
|||||||
s_in = x.new_ones([x.shape[0]]).half()
|
s_in = x.new_ones([x.shape[0]]).half()
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
@@ -807,13 +822,18 @@ class UNet(DDPM):
|
@@ -807,13 +826,18 @@ class UNet(DDPM):
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigma_hat, denoised)
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||||
@ -149,7 +161,7 @@ index b967b55..e06e653 100644
|
|||||||
"""Ancestral sampling with Euler method steps."""
|
"""Ancestral sampling with Euler method steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
|
||||||
@@ -822,6 +842,8 @@ class UNet(DDPM):
|
@@ -822,6 +846,8 @@ class UNet(DDPM):
|
||||||
sigmas = cvd.get_sigmas(S)
|
sigmas = cvd.get_sigmas(S)
|
||||||
x = x*sigmas[0]
|
x = x*sigmas[0]
|
||||||
|
|
||||||
@ -158,7 +170,7 @@ index b967b55..e06e653 100644
|
|||||||
s_in = x.new_ones([x.shape[0]]).half()
|
s_in = x.new_ones([x.shape[0]]).half()
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
|
||||||
@@ -837,17 +859,22 @@ class UNet(DDPM):
|
@@ -837,17 +863,22 @@ class UNet(DDPM):
|
||||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
@ -183,7 +195,7 @@ index b967b55..e06e653 100644
|
|||||||
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
|
||||||
@@ -855,6 +882,8 @@ class UNet(DDPM):
|
@@ -855,6 +886,8 @@ class UNet(DDPM):
|
||||||
sigmas = cvd.get_sigmas(S)
|
sigmas = cvd.get_sigmas(S)
|
||||||
x = x*sigmas[0]
|
x = x*sigmas[0]
|
||||||
|
|
||||||
@ -192,7 +204,7 @@ index b967b55..e06e653 100644
|
|||||||
|
|
||||||
s_in = x.new_ones([x.shape[0]]).half()
|
s_in = x.new_ones([x.shape[0]]).half()
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
@@ -876,6 +905,9 @@ class UNet(DDPM):
|
@@ -876,6 +909,9 @@ class UNet(DDPM):
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigma_hat, denoised)
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
|
||||||
@ -202,7 +214,7 @@ index b967b55..e06e653 100644
|
|||||||
dt = sigmas[i + 1] - sigma_hat
|
dt = sigmas[i + 1] - sigma_hat
|
||||||
if sigmas[i + 1] == 0:
|
if sigmas[i + 1] == 0:
|
||||||
# Euler method
|
# Euler method
|
||||||
@@ -895,11 +927,13 @@ class UNet(DDPM):
|
@@ -895,11 +931,13 @@ class UNet(DDPM):
|
||||||
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
|
||||||
d_prime = (d + d_2) / 2
|
d_prime = (d + d_2) / 2
|
||||||
x = x + d_prime * dt
|
x = x + d_prime * dt
|
||||||
@ -218,7 +230,7 @@ index b967b55..e06e653 100644
|
|||||||
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
|
||||||
@@ -907,6 +941,8 @@ class UNet(DDPM):
|
@@ -907,6 +945,8 @@ class UNet(DDPM):
|
||||||
sigmas = cvd.get_sigmas(S)
|
sigmas = cvd.get_sigmas(S)
|
||||||
x = x*sigmas[0]
|
x = x*sigmas[0]
|
||||||
|
|
||||||
@ -227,7 +239,7 @@ index b967b55..e06e653 100644
|
|||||||
s_in = x.new_ones([x.shape[0]]).half()
|
s_in = x.new_ones([x.shape[0]]).half()
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
|
||||||
@@ -924,7 +960,7 @@ class UNet(DDPM):
|
@@ -924,7 +964,7 @@ class UNet(DDPM):
|
||||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
@ -236,7 +248,7 @@ index b967b55..e06e653 100644
|
|||||||
|
|
||||||
d = to_d(x, sigma_hat, denoised)
|
d = to_d(x, sigma_hat, denoised)
|
||||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||||
@@ -945,11 +981,13 @@ class UNet(DDPM):
|
@@ -945,11 +985,13 @@ class UNet(DDPM):
|
||||||
|
|
||||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||||
x = x + d_2 * dt_2
|
x = x + d_2 * dt_2
|
||||||
@ -252,7 +264,7 @@ index b967b55..e06e653 100644
|
|||||||
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
|
||||||
@@ -957,6 +995,8 @@ class UNet(DDPM):
|
@@ -957,6 +999,8 @@ class UNet(DDPM):
|
||||||
sigmas = cvd.get_sigmas(S)
|
sigmas = cvd.get_sigmas(S)
|
||||||
x = x*sigmas[0]
|
x = x*sigmas[0]
|
||||||
|
|
||||||
@ -261,7 +273,7 @@ index b967b55..e06e653 100644
|
|||||||
s_in = x.new_ones([x.shape[0]]).half()
|
s_in = x.new_ones([x.shape[0]]).half()
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
|
||||||
@@ -973,6 +1013,9 @@ class UNet(DDPM):
|
@@ -973,6 +1017,9 @@ class UNet(DDPM):
|
||||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||||
@ -271,7 +283,7 @@ index b967b55..e06e653 100644
|
|||||||
d = to_d(x, sigmas[i], denoised)
|
d = to_d(x, sigmas[i], denoised)
|
||||||
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
|
||||||
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
|
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
|
||||||
@@ -993,11 +1036,13 @@ class UNet(DDPM):
|
@@ -993,11 +1040,13 @@ class UNet(DDPM):
|
||||||
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
||||||
x = x + d_2 * dt_2
|
x = x + d_2 * dt_2
|
||||||
x = x + torch.randn_like(x) * sigma_up
|
x = x + torch.randn_like(x) * sigma_up
|
||||||
@ -287,7 +299,7 @@ index b967b55..e06e653 100644
|
|||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
|
||||||
@@ -1005,6 +1050,8 @@ class UNet(DDPM):
|
@@ -1005,6 +1054,8 @@ class UNet(DDPM):
|
||||||
sigmas = cvd.get_sigmas(S)
|
sigmas = cvd.get_sigmas(S)
|
||||||
x = x*sigmas[0]
|
x = x*sigmas[0]
|
||||||
|
|
||||||
@ -296,7 +308,7 @@ index b967b55..e06e653 100644
|
|||||||
ds = []
|
ds = []
|
||||||
for i in trange(len(sigmas) - 1, disable=disable):
|
for i in trange(len(sigmas) - 1, disable=disable):
|
||||||
|
|
||||||
@@ -1017,6 +1064,7 @@ class UNet(DDPM):
|
@@ -1017,6 +1068,7 @@ class UNet(DDPM):
|
||||||
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
|
||||||
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
@ -304,7 +316,7 @@ index b967b55..e06e653 100644
|
|||||||
|
|
||||||
d = to_d(x, sigmas[i], denoised)
|
d = to_d(x, sigmas[i], denoised)
|
||||||
ds.append(d)
|
ds.append(d)
|
||||||
@@ -1027,4 +1075,5 @@ class UNet(DDPM):
|
@@ -1027,4 +1079,5 @@ class UNet(DDPM):
|
||||||
cur_order = min(i + 1, order)
|
cur_order = min(i + 1, order)
|
||||||
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
|
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
|
||||||
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
||||||
|
Loading…
Reference in New Issue
Block a user