Mirror of https://github.com/easydiffusion/easydiffusion.git

commit 642c114501 (parent 02dd3e457d)

    Working txt2img
Windows installer script:

@@ -46,7 +46,7 @@ if NOT DEFINED test_sd2 set test_sd2=N
     @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
 )
 if "%test_sd2%" == "Y" (
-    @call git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94
+    @call git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
 
     @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
 )
Shell installer script:

@@ -40,7 +40,7 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
 
     git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
 elif [ "$test_sd2" == "Y" ]; then
-    git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94
+    git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
 
     git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
 fi
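Both installers make the same change: the Stable Diffusion 2.0 test checkout moves from commit 5a14697a to 992f1113, matching the rewritten SD2 patch below. The -c advice.detachedHead=false flag merely silences git's detached-HEAD warning for these one-off checkouts.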
The SD2 callback patch (ui/sd_internal/ddim_callback_sd2.patch) is rewritten wholesale (@@ -1,20 +1,84 @@). The old patch instrumented the optimizedSD fork:

diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index 1bbdd02..cd00cc3 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule):
     def sample(self, batch_size=16, return_intermediates=False):
         image_size = self.image_size
         channels = self.channels
+        print('sampler 2')
         return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                   return_intermediates=return_intermediates)
 
@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM):
     def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
                verbose=True, timesteps=None, quantize_denoised=False,
                mask=None, x0=None, shape=None, **kwargs):
+        print('sampler 1')
         if shape is None:
             shape = (batch_size, self.channels, self.image_size, self.image_size)
         if cond is not None:

The new patch instead converts the stock ldm DDIM and PLMS samplers into generators, streaming intermediate images out through the callbacks instead of returning everything at the end:

diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
         size = (batch_size, C, H, W)
         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
 
-        samples, intermediates = self.ddim_sampling(conditioning, size,
+        samples = self.ddim_sampling(conditioning, size,
                                                     callback=callback,
                                                     img_callback=img_callback,
                                                     quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
                                                     dynamic_threshold=dynamic_threshold,
                                                     ucg_schedule=ucg_schedule
                                                     )
-        return samples, intermediates
+        # return samples, intermediates
+        yield from samples
 
     @torch.no_grad()
     def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
                                       unconditional_conditioning=unconditional_conditioning,
                                       dynamic_threshold=dynamic_threshold)
             img, pred_x0 = outs
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
+            if callback: yield from callback(i)
+            if img_callback: yield from img_callback(pred_x0, i)
 
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates['x_inter'].append(img)
                 intermediates['pred_x0'].append(pred_x0)
 
-        return img, intermediates
+        # return img, intermediates
+        yield from img_callback(pred_x0, len(iterator)-1)
 
     @torch.no_grad()
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
         size = (batch_size, C, H, W)
         print(f'Data shape for PLMS sampling is {size}')
 
-        samples, intermediates = self.plms_sampling(conditioning, size,
+        samples = self.plms_sampling(conditioning, size,
                                                     callback=callback,
                                                     img_callback=img_callback,
                                                     quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
                                                     unconditional_conditioning=unconditional_conditioning,
                                                     dynamic_threshold=dynamic_threshold,
                                                     )
-        return samples, intermediates
+        #return samples, intermediates
+        yield from samples
 
     @torch.no_grad()
     def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
             old_eps.append(e_t)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
+            if callback: yield from callback(i)
+            if img_callback: yield from img_callback(pred_x0, i)
 
             if index % log_every_t == 0 or index == total_steps - 1:
                 intermediates['x_inter'].append(img)
                 intermediates['pred_x0'].append(pred_x0)
 
-        return img, intermediates
+        # return img, intermediates
+        yield from img_callback(pred_x0, len(iterator)-1)
 
     @torch.no_grad()
     def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
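The net effect of the rewritten patch: sample() no longer returns (samples, intermediates). It becomes a generator, and whatever the callbacks yield is forwarded out through the yield from chain. Below is a minimal consumer sketch, not from the commit: it assumes a patched DDIMSampler plus conditioning tensors c/uc obtained elsewhere, and the progress-dict shape is illustrative.

def img_callback(pred_x0, step):
    # The patched samplers call `yield from img_callback(pred_x0, i)`,
    # so the callback must itself be a generator; each value it yields
    # propagates all the way out of sampler.sample().
    yield {'step': step, 'latent': pred_x0}

def run_streaming(sampler, c, uc, steps=50):
    # Iterating sample() is now what drives the denoising loop.
    for progress in sampler.sample(
            S=steps,
            conditioning=c,
            batch_size=1,
            shape=[4, 64, 64],  # [C, H//f, W//f] for a 512x512 output
            verbose=False,
            unconditional_guidance_scale=7.5,
            unconditional_conditioning=uc,
            eta=0.0,
            x_T=None,
            img_callback=img_callback):
        print(f"intermediate latent at step {progress['step']}")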
And in the Python runtime's _txt2img(), Stable Diffusion 2.0 requests are routed through the patched ldm samplers:

@@ -749,6 +749,43 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
     if sampler_name == 'ddim':
         thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
 
+    # samples, _ = sampler.sample(S=opt.steps,
+    #                             conditioning=c,
+    #                             batch_size=opt.n_samples,
+    #                             shape=shape,
+    #                             verbose=False,
+    #                             unconditional_guidance_scale=opt.scale,
+    #                             unconditional_conditioning=uc,
+    #                             eta=opt.ddim_eta,
+    #                             x_T=start_code)
+
+    if thread_data.test_sd2:
+        from ldm.models.diffusion.ddim import DDIMSampler
+        from ldm.models.diffusion.plms import PLMSSampler
+
+        shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
+
+        if sampler_name == 'plms':
+            sampler = PLMSSampler(thread_data.model)
+        elif sampler_name == 'ddim':
+            sampler = DDIMSampler(thread_data.model)
+
+        samples_ddim = sampler.sample(
+            S=opt_ddim_steps,
+            conditioning=c,
+            batch_size=opt_n_samples,
+            seed=opt_seed,
+            shape=shape,
+            verbose=False,
+            unconditional_guidance_scale=opt_scale,
+            unconditional_conditioning=uc,
+            eta=opt_ddim_eta,
+            x_T=start_code,
+            img_callback=img_callback,
+            mask=mask,
+            sampler = sampler_name,
+        )
+    else:
+
         samples_ddim = thread_data.model.sample(
             S=opt_ddim_steps,
             conditioning=c,
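One consequence worth noting: with the patched samplers, samples_ddim in the test_sd2 branch is a generator object rather than a tensor, and no sampling work happens until it is iterated. A sketch of how a caller might drain it (an assumption for illustration; the commit's actual consumption of samples_ddim lies outside this hunk):

def drain(samples_gen):
    # Every item was yielded by callback/img_callback during the
    # denoising loop; the last one corresponds to the final step.
    last = None
    for item in samples_gen:
        last = item
    return last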