Working txt2img

This commit is contained in:
cmdr2 2022-11-25 14:29:24 +05:30
parent 02dd3e457d
commit 642c114501
4 changed files with 136 additions and 35 deletions

View File

@ -46,7 +46,7 @@ if NOT DEFINED test_sd2 set test_sd2=N
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
)
if "%test_sd2%" == "Y" (
@call git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94
@call git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
)

View File

@ -40,7 +40,7 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
elif [ "$test_sd2" == "Y" ]; then
git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94
git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
fi

View File

@ -1,20 +1,84 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index 1bbdd02..cd00cc3 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule):
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
+ print('sampler 2')
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM):
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None, **kwargs):
+ print('sampler 1')
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
- samples, intermediates = self.ddim_sampling(conditioning, size,
+ samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
- return samples, intermediates
+ # return samples, intermediates
+ yield from samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
- samples, intermediates = self.plms_sampling(conditioning, size,
+ samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
- return samples, intermediates
+ #return samples, intermediates
+ yield from samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,

View File

@ -749,20 +749,57 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
# samples, _ = sampler.sample(S=opt.steps,
# conditioning=c,
# batch_size=opt.n_samples,
# shape=shape,
# verbose=False,
# unconditional_guidance_scale=opt.scale,
# unconditional_conditioning=uc,
# eta=opt.ddim_eta,
# x_T=start_code)
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
if sampler_name == 'plms':
sampler = PLMSSampler(thread_data.model)
elif sampler_name == 'ddim':
sampler = DDIMSampler(thread_data.model)
samples_ddim = sampler.sample(
S=opt_ddim_steps,
conditioning=c,
batch_size=opt_n_samples,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
else:
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):