Fix the 'Expected all tensors to be on the same device' error

This commit is contained in:
cmdr2 2022-09-23 11:44:50 +05:30
parent 83cb473a45
commit a3de0820b3
2 changed files with 30 additions and 24 deletions

View File

@ -1,5 +1,5 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index b967b55..10a7c32 100644 index b967b55..35ef520 100644
--- a/optimizedSD/ddpm.py --- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py
@@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
@ -11,7 +11,16 @@ index b967b55..10a7c32 100644
def disabled_train(self): def disabled_train(self):
"""Overwrite model.train with this function to make sure train/eval mode """Overwrite model.train with this function to make sure train/eval mode
@@ -528,39 +528,46 @@ class UNet(DDPM): @@ -506,6 +506,8 @@ class UNet(DDPM):
x_latent = noise if x0 is None else x0
# sampling
+ if sampler in ('ddim', 'dpm2', 'heun', 'dpm2_a', 'lms') and not hasattr(self, 'ddim_timesteps'):
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
if sampler == "plms":
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
@@ -528,39 +530,46 @@ class UNet(DDPM):
elif sampler == "ddim": elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
@ -67,7 +76,7 @@ index b967b55..10a7c32 100644
@torch.no_grad() @torch.no_grad()
def plms_sampling(self, cond,b, img, def plms_sampling(self, cond,b, img,
ddim_use_original_steps=False, ddim_use_original_steps=False,
@@ -599,10 +606,10 @@ class UNet(DDPM): @@ -599,10 +608,10 @@ class UNet(DDPM):
old_eps.append(e_t) old_eps.append(e_t)
if len(old_eps) >= 4: if len(old_eps) >= 4:
old_eps.pop(0) old_eps.pop(0)
@ -81,7 +90,7 @@ index b967b55..10a7c32 100644
@torch.no_grad() @torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -706,7 +713,8 @@ class UNet(DDPM): @@ -706,7 +715,8 @@ class UNet(DDPM):
@torch.no_grad() @torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
@ -91,7 +100,7 @@ index b967b55..10a7c32 100644
timesteps = self.ddim_timesteps timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start] timesteps = timesteps[:t_start]
@@ -730,10 +738,13 @@ class UNet(DDPM): @@ -730,10 +740,13 @@ class UNet(DDPM):
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning) unconditional_conditioning=unconditional_conditioning)
@ -107,7 +116,7 @@ index b967b55..10a7c32 100644
@torch.no_grad() @torch.no_grad()
@@ -779,13 +790,16 @@ class UNet(DDPM): @@ -779,13 +792,16 @@ class UNet(DDPM):
@torch.no_grad() @torch.no_grad()
@ -125,7 +134,7 @@ index b967b55..10a7c32 100644
s_in = x.new_ones([x.shape[0]]).half() s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
@@ -807,13 +821,18 @@ class UNet(DDPM): @@ -807,13 +823,18 @@ class UNet(DDPM):
d = to_d(x, sigma_hat, denoised) d = to_d(x, sigma_hat, denoised)
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
@ -146,7 +155,7 @@ index b967b55..10a7c32 100644
"""Ancestral sampling with Euler method steps.""" """Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
@@ -822,6 +841,8 @@ class UNet(DDPM): @@ -822,6 +843,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S) sigmas = cvd.get_sigmas(S)
x = x*sigmas[0] x = x*sigmas[0]
@ -155,7 +164,7 @@ index b967b55..10a7c32 100644
s_in = x.new_ones([x.shape[0]]).half() s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
@@ -837,17 +858,22 @@ class UNet(DDPM): @@ -837,17 +860,22 @@ class UNet(DDPM):
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@ -180,7 +189,7 @@ index b967b55..10a7c32 100644
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
@@ -855,6 +881,8 @@ class UNet(DDPM): @@ -855,6 +883,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S) sigmas = cvd.get_sigmas(S)
x = x*sigmas[0] x = x*sigmas[0]
@ -189,7 +198,7 @@ index b967b55..10a7c32 100644
s_in = x.new_ones([x.shape[0]]).half() s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
@@ -876,6 +904,9 @@ class UNet(DDPM): @@ -876,6 +906,9 @@ class UNet(DDPM):
d = to_d(x, sigma_hat, denoised) d = to_d(x, sigma_hat, denoised)
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
@ -199,7 +208,7 @@ index b967b55..10a7c32 100644
dt = sigmas[i + 1] - sigma_hat dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0: if sigmas[i + 1] == 0:
# Euler method # Euler method
@@ -895,11 +926,13 @@ class UNet(DDPM): @@ -895,11 +928,13 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigmas[i + 1], denoised_2) d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2 d_prime = (d + d_2) / 2
x = x + d_prime * dt x = x + d_prime * dt
@ -215,7 +224,7 @@ index b967b55..10a7c32 100644
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
@@ -907,6 +940,8 @@ class UNet(DDPM): @@ -907,6 +942,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S) sigmas = cvd.get_sigmas(S)
x = x*sigmas[0] x = x*sigmas[0]
@ -224,7 +233,7 @@ index b967b55..10a7c32 100644
s_in = x.new_ones([x.shape[0]]).half() s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
@@ -924,7 +959,7 @@ class UNet(DDPM): @@ -924,7 +961,7 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
@ -233,7 +242,7 @@ index b967b55..10a7c32 100644
d = to_d(x, sigma_hat, denoised) d = to_d(x, sigma_hat, denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
@@ -945,11 +980,13 @@ class UNet(DDPM): @@ -945,11 +982,13 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2) d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2 x = x + d_2 * dt_2
@ -249,7 +258,7 @@ index b967b55..10a7c32 100644
"""Ancestral sampling with DPM-Solver inspired second-order steps.""" """Ancestral sampling with DPM-Solver inspired second-order steps."""
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
@@ -957,6 +994,8 @@ class UNet(DDPM): @@ -957,6 +996,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S) sigmas = cvd.get_sigmas(S)
x = x*sigmas[0] x = x*sigmas[0]
@ -258,7 +267,7 @@ index b967b55..10a7c32 100644
s_in = x.new_ones([x.shape[0]]).half() s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
@@ -973,6 +1012,9 @@ class UNet(DDPM): @@ -973,6 +1014,9 @@ class UNet(DDPM):
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None: if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
@ -268,7 +277,7 @@ index b967b55..10a7c32 100644
d = to_d(x, sigmas[i], denoised) d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
@@ -993,11 +1035,13 @@ class UNet(DDPM): @@ -993,11 +1037,13 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2) d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2 x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up x = x + torch.randn_like(x) * sigma_up
@ -284,7 +293,7 @@ index b967b55..10a7c32 100644
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
@@ -1005,6 +1049,8 @@ class UNet(DDPM): @@ -1005,6 +1051,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S) sigmas = cvd.get_sigmas(S)
x = x*sigmas[0] x = x*sigmas[0]
@ -293,7 +302,7 @@ index b967b55..10a7c32 100644
ds = [] ds = []
for i in trange(len(sigmas) - 1, disable=disable): for i in trange(len(sigmas) - 1, disable=disable):
@@ -1017,6 +1063,7 @@ class UNet(DDPM): @@ -1017,6 +1065,7 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
@ -301,7 +310,7 @@ index b967b55..10a7c32 100644
d = to_d(x, sigmas[i], denoised) d = to_d(x, sigmas[i], denoised)
ds.append(d) ds.append(d)
@@ -1027,4 +1074,5 @@ class UNet(DDPM): @@ -1027,4 +1076,5 @@ class UNet(DDPM):
cur_order = min(i + 1, order) cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))

View File

@ -522,9 +522,6 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
while torch.cuda.memory_allocated() / 1e6 >= mem: while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1) time.sleep(1)
if sampler_name == 'ddim' and not hasattr(model, 'ddim_timesteps'):
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = model.sample( samples_ddim = model.sample(
S=opt_ddim_steps, S=opt_ddim_steps,
conditioning=c, conditioning=c,