diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 57337107..03096642 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -46,7 +46,7 @@ if NOT DEFINED test_sd2 set test_sd2=N @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch ) if "%test_sd2%" == "Y" ( - @call git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94 + @call git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3 @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch ) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index a61c5e1c..460d58a7 100644 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -40,7 +40,7 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" elif [ "$test_sd2" == "Y" ]; then - git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94 + git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3 git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed" fi diff --git a/ui/sd_internal/ddim_callback_sd2.patch b/ui/sd_internal/ddim_callback_sd2.patch index 00700d00..cadf81ca 100644 --- a/ui/sd_internal/ddim_callback_sd2.patch +++ b/ui/sd_internal/ddim_callback_sd2.patch @@ -1,20 +1,84 @@ -diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py -index 1bbdd02..cd00cc3 100644 ---- a/optimizedSD/ddpm.py -+++ b/optimizedSD/ddpm.py -@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule): - def sample(self, batch_size=16, return_intermediates=False): - image_size = self.image_size - channels = self.channels -+ print('sampler 2') - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) +diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py +index 27ead0e..6215939 100644 +--- a/ldm/models/diffusion/ddim.py ++++ b/ldm/models/diffusion/ddim.py +@@ -100,7 +100,7 @@ class DDIMSampler(object): + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') -@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM): - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None, **kwargs): -+ print('sampler 1') - if shape is None: - shape = (batch_size, self.channels, self.image_size, self.image_size) - if cond is not None: +- samples, intermediates = self.ddim_sampling(conditioning, size, ++ samples = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, +@@ -117,7 +117,8 @@ class DDIMSampler(object): + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) +- return samples, intermediates ++ # return samples, intermediates ++ yield from samples + + @torch.no_grad() + def ddim_sampling(self, cond, shape, +@@ -168,14 +169,15 @@ class DDIMSampler(object): + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs +- if callback: callback(i) +- if img_callback: img_callback(pred_x0, i) ++ if callback: yield from callback(i) ++ if img_callback: yield from img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + +- return img, intermediates ++ # return img, intermediates ++ yield from img_callback(pred_x0, len(iterator)-1) + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, +diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py +index 7002a36..0951f39 100644 +--- a/ldm/models/diffusion/plms.py ++++ b/ldm/models/diffusion/plms.py +@@ -96,7 +96,7 @@ class PLMSSampler(object): + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + +- samples, intermediates = self.plms_sampling(conditioning, size, ++ samples = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, +@@ -112,7 +112,8 @@ class PLMSSampler(object): + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) +- return samples, intermediates ++ #return samples, intermediates ++ yield from samples + + @torch.no_grad() + def plms_sampling(self, cond, shape, +@@ -165,14 +166,15 @@ class PLMSSampler(object): + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) +- if callback: callback(i) +- if img_callback: img_callback(pred_x0, i) ++ if callback: yield from callback(i) ++ if img_callback: yield from img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + +- return img, intermediates ++ # return img, intermediates ++ yield from img_callback(pred_x0, len(iterator)-1) + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index b5b61f65..c168ddcf 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -749,20 +749,57 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, if sampler_name == 'ddim': thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - samples_ddim = thread_data.model.sample( - S=opt_ddim_steps, - conditioning=c, - seed=opt_seed, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - eta=opt_ddim_eta, - x_T=start_code, - img_callback=img_callback, - mask=mask, - sampler = sampler_name, - ) + # samples, _ = sampler.sample(S=opt.steps, + # conditioning=c, + # batch_size=opt.n_samples, + # shape=shape, + # verbose=False, + # unconditional_guidance_scale=opt.scale, + # unconditional_conditioning=uc, + # eta=opt.ddim_eta, + # x_T=start_code) + + if thread_data.test_sd2: + from ldm.models.diffusion.ddim import DDIMSampler + from ldm.models.diffusion.plms import PLMSSampler + + shape = [opt_C, opt_H // opt_f, opt_W // opt_f] + + if sampler_name == 'plms': + sampler = PLMSSampler(thread_data.model) + elif sampler_name == 'ddim': + sampler = DDIMSampler(thread_data.model) + + samples_ddim = sampler.sample( + S=opt_ddim_steps, + conditioning=c, + batch_size=opt_n_samples, + seed=opt_seed, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt_scale, + unconditional_conditioning=uc, + eta=opt_ddim_eta, + x_T=start_code, + img_callback=img_callback, + mask=mask, + sampler = sampler_name, + ) + else: + samples_ddim = thread_data.model.sample( + S=opt_ddim_steps, + conditioning=c, + seed=opt_seed, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt_scale, + unconditional_conditioning=uc, + eta=opt_ddim_eta, + x_T=start_code, + img_callback=img_callback, + mask=mask, + sampler = sampler_name, + ) yield from samples_ddim def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):