From 956b3d89db62a853bc7cdac250123ba8c2fc22f5 Mon Sep 17 00:00:00 2001 From: cmdr2 <secondary.cmdr2@gmail.com> Date: Fri, 23 Sep 2022 00:19:05 +0530 Subject: [PATCH] New samplers for txt2img: "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" --- ui/index.html | 22 +- ui/sd_internal/__init__.py | 3 + ui/sd_internal/ddim_callback.patch | 323 ++++++++++++++++++++++------- ui/sd_internal/runtime.py | 37 ++-- ui/server.py | 2 + 5 files changed, 291 insertions(+), 96 deletions(-) diff --git a/ui/index.html b/ui/index.html index e8afeae1..dd1924b0 100644 --- a/ui/index.html +++ b/ui/index.html @@ -304,7 +304,7 @@ <div id="server-status-color"> </div> <span id="server-status-msg">Stable Diffusion is starting..</span> </div> - <h1>Stable Diffusion UI <small>v2.14 <span id="updateBranchLabel"></span></small></h1> + <h1>Stable Diffusion UI <small>v2.15 <span id="updateBranchLabel"></span></small></h1> </div> <div id="editor-inputs"> <div id="editor-inputs-prompt" class="row"> @@ -353,6 +353,18 @@ <br/> <li><label for="seed">Seed:</label> <input id="seed" name="seed" size="10" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label></li> <li><label for="num_outputs_total">Number of images to make:</label> <input id="num_outputs_total" name="num_outputs_total" value="1" size="4"> <label for="num_outputs_parallel">Generate in parallel:</label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="4"> (images at once)</li> + <li id="samplerSelection"><label for="sampler">Sampler:</label> + <select id="sampler" name="sampler"> + <option value="plms" selected>plms</option> + <option value="ddim">ddim</option> + <option value="heun">heun</option> + <option value="euler">euler</option> + <option value="euler_a">euler_a</option> + <option value="dpm2">dpm2</option> + <option value="dpm2_a">dpm2_a</option> + <option value="lms">lms</option> + </select> + </li> <li><label for="width">Width:</label> <select id="width" name="width" value="512"> <option value="128">128 (*)</option> @@ -488,6 +500,8 @@ let diskPathField = document.querySelector('#diskPath') let useBetaChannelField = document.querySelector("#use_beta_channel") let promptStrengthSlider = document.querySelector('#prompt_strength_slider') let promptStrengthField = document.querySelector('#prompt_strength') +let samplerField = document.querySelector('#sampler') +let samplerSelectionContainer = document.querySelector("#samplerSelection") let useFaceCorrectionField = document.querySelector("#use_face_correction") let useUpscalingField = document.querySelector("#use_upscale") let upscaleModelField = document.querySelector("#upscale_model") @@ -1007,6 +1021,10 @@ async function makeImage() { if (maskSetting.checked) { reqBody['mask'] = inpaintingEditor.getImg() } + + reqBody['sampler'] = 'ddim' + } else { + reqBody['sampler'] = samplerField.value } if (saveToDiskField.checked && diskPathField.value.trim() !== '') { @@ -1275,6 +1293,7 @@ function showInitImagePreview() { initImagePreviewContainer.style.display = 'block' inpaintingEditorContainer.style.display = 'none' promptStrengthContainer.style.display = 'block' + samplerSelectionContainer.style.display = 'none' // maskSetting.checked = false }) @@ -1306,6 +1325,7 @@ initImageClearBtn.addEventListener('click', function() { // maskSetting.style.display = 'none' promptStrengthContainer.style.display = 'none' + samplerSelectionContainer.style.display = 'block' }) maskSetting.addEventListener('click', function() { diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index c0b6c6dc..e7630a1a 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -12,6 +12,7 @@ class Request: height: int = 512 seed: int = 42 prompt_strength: float = 0.8 + sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" # allow_nsfw: bool = False precision: str = "autocast" # or "full" save_to_disk_path: str = None @@ -36,6 +37,7 @@ class Request: "height": self.height, "seed": self.seed, "prompt_strength": self.prompt_strength, + "sampler": self.sampler, "use_face_correction": self.use_face_correction, "use_upscale": self.use_upscale, } @@ -46,6 +48,7 @@ class Request: prompt: {self.prompt} seed: {self.seed} num_inference_steps: {self.num_inference_steps} + sampler: {self.sampler} guidance_scale: {self.guidance_scale} w: {self.width} h: {self.height} diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch index 7e40d614..9852aab8 100644 --- a/ui/sd_internal/ddim_callback.patch +++ b/ui/sd_internal/ddim_callback.patch @@ -1,5 +1,5 @@ diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py -index b967b55..75ddd8b 100644 +index b967b55..10a7c32 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config @@ -11,122 +11,303 @@ index b967b55..75ddd8b 100644 def disabled_train(self): """Overwrite model.train with this function to make sure train/eval mode -@@ -485,6 +485,7 @@ class UNet(DDPM): - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, -+ streaming_callbacks = False, - ): - - -@@ -523,12 +524,15 @@ class UNet(DDPM): - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, -+ streaming_callbacks=streaming_callbacks - ) - +@@ -528,39 +528,46 @@ class UNet(DDPM): elif sampler == "ddim": samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - mask = mask,init_latent=x_T,use_original_steps=False) + mask = mask,init_latent=x_T,use_original_steps=False, -+ callback=callback, img_callback=img_callback, -+ streaming_callbacks=streaming_callbacks) ++ callback=callback, img_callback=img_callback) elif sampler == "euler": self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) -@@ -555,11 +559,15 @@ class UNet(DDPM): - samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, - unconditional_guidance_scale=unconditional_guidance_scale) + samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, +- unconditional_guidance_scale=unconditional_guidance_scale) ++ unconditional_guidance_scale=unconditional_guidance_scale, ++ img_callback=img_callback) + elif sampler == "euler_a": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) + samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, +- unconditional_guidance_scale=unconditional_guidance_scale) ++ unconditional_guidance_scale=unconditional_guidance_scale, ++ img_callback=img_callback) -+ if streaming_callbacks: # this line needs to be right after the sampling() call -+ yield from samples + elif sampler == "dpm2": + samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, +- unconditional_guidance_scale=unconditional_guidance_scale) ++ unconditional_guidance_scale=unconditional_guidance_scale, ++ img_callback=img_callback) + elif sampler == "heun": + samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, +- unconditional_guidance_scale=unconditional_guidance_scale) ++ unconditional_guidance_scale=unconditional_guidance_scale, ++ img_callback=img_callback) + + elif sampler == "dpm2_a": + samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, +- unconditional_guidance_scale=unconditional_guidance_scale) ++ unconditional_guidance_scale=unconditional_guidance_scale, ++ img_callback=img_callback) + + + elif sampler == "lms": + samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, +- unconditional_guidance_scale=unconditional_guidance_scale) ++ unconditional_guidance_scale=unconditional_guidance_scale, ++ img_callback=img_callback) + ++ yield from samples + if(self.turbo): self.model1.to("cpu") self.model2.to("cpu") - return samples -+ if not streaming_callbacks: -+ return samples - +- @torch.no_grad() def plms_sampling(self, cond,b, img, -@@ -567,7 +575,8 @@ class UNet(DDPM): - callback=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, -- unconditional_guidance_scale=1., unconditional_conditioning=None,): -+ unconditional_guidance_scale=1., unconditional_conditioning=None, -+ streaming_callbacks=False): - - device = self.betas.device - timesteps = self.ddim_timesteps -@@ -599,10 +608,21 @@ class UNet(DDPM): + ddim_use_original_steps=False, +@@ -599,10 +606,10 @@ class UNet(DDPM): old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) -- ++ if callback: yield from callback(i) ++ if img_callback: yield from img_callback(pred_x0, i) + - return img -+ if callback: -+ if streaming_callbacks: -+ yield from callback(i) -+ else: -+ callback(i) -+ if img_callback: -+ if streaming_callbacks: -+ yield from img_callback(pred_x0, i) -+ else: -+ img_callback(pred_x0, i) -+ -+ if streaming_callbacks and img_callback: -+ yield from img_callback(img, len(iterator)-1) -+ else: -+ return img ++ yield from img_callback(img, len(iterator)-1) @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -@@ -706,7 +726,9 @@ class UNet(DDPM): +@@ -706,7 +713,8 @@ class UNet(DDPM): @torch.no_grad() def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - mask = None,init_latent=None,use_original_steps=False): + mask = None,init_latent=None,use_original_steps=False, -+ callback=None, img_callback=None, -+ streaming_callbacks=False): ++ callback=None, img_callback=None): timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] -@@ -730,10 +752,24 @@ class UNet(DDPM): +@@ -730,10 +738,13 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) -+ if callback: -+ if streaming_callbacks: -+ yield from callback(i) -+ else: -+ callback(i) -+ if img_callback: -+ if streaming_callbacks: -+ yield from img_callback(x_dec, i) -+ else: -+ img_callback(x_dec, i) ++ if callback: yield from callback(i) ++ if img_callback: yield from img_callback(x_dec, i) + if mask is not None: - return x0 * mask + (1. - mask) * x_dec + x_dec = x0 * mask + (1. - mask) * x_dec - return x_dec -+ if streaming_callbacks and img_callback: -+ yield from img_callback(x_dec, len(iterator)-1) -+ else: -+ return x_dec ++ yield from img_callback(x_dec, len(iterator)-1) @torch.no_grad() +@@ -779,13 +790,16 @@ class UNet(DDPM): + + + @torch.no_grad() +- def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): ++ def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., ++ img_callback=None): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + cvd = CompVisDenoiser(ac) + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + ++ print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps") ++ + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. +@@ -807,13 +821,18 @@ class UNet(DDPM): + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) ++ ++ if img_callback: yield from img_callback(x, i) ++ + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt +- return x ++ ++ yield from img_callback(x, len(sigmas)-1) + + @torch.no_grad() +- def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None): ++ def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, ++ img_callback=None): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + +@@ -822,6 +841,8 @@ class UNet(DDPM): + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + ++ print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps") ++ + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + +@@ -837,17 +858,22 @@ class UNet(DDPM): + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) ++ ++ if img_callback: yield from img_callback(x, i) ++ + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + torch.randn_like(x) * sigma_up +- return x ++ ++ yield from img_callback(x, len(sigmas)-1) + + + + @torch.no_grad() +- def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): ++ def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., ++ img_callback=None): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + +@@ -855,6 +881,8 @@ class UNet(DDPM): + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + ++ print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps") ++ + + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): +@@ -876,6 +904,9 @@ class UNet(DDPM): + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) ++ ++ if img_callback: yield from img_callback(x, i) ++ + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method +@@ -895,11 +926,13 @@ class UNet(DDPM): + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt +- return x ++ ++ yield from img_callback(x, len(sigmas)-1) + + + @torch.no_grad() +- def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): ++ def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., ++ img_callback=None): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + +@@ -907,6 +940,8 @@ class UNet(DDPM): + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + ++ print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps") ++ + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. +@@ -924,7 +959,7 @@ class UNet(DDPM): + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + +- ++ if img_callback: yield from img_callback(x, i) + + d = to_d(x, sigma_hat, denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule +@@ -945,11 +980,13 @@ class UNet(DDPM): + + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 +- return x ++ ++ yield from img_callback(x, len(sigmas)-1) + + + @torch.no_grad() +- def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None): ++ def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, ++ img_callback=None): + """Ancestral sampling with DPM-Solver inspired second-order steps.""" + extra_args = {} if extra_args is None else extra_args + +@@ -957,6 +994,8 @@ class UNet(DDPM): + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + ++ print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps") ++ + s_in = x.new_ones([x.shape[0]]).half() + for i in trange(len(sigmas) - 1, disable=disable): + +@@ -973,6 +1012,9 @@ class UNet(DDPM): + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) ++ ++ if img_callback: yield from img_callback(x, i) ++ + d = to_d(x, sigmas[i], denoised) + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 +@@ -993,11 +1035,13 @@ class UNet(DDPM): + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + torch.randn_like(x) * sigma_up +- return x ++ ++ yield from img_callback(x, len(sigmas)-1) + + + @torch.no_grad() +- def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4): ++ def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4, ++ img_callback=None): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + +@@ -1005,6 +1049,8 @@ class UNet(DDPM): + sigmas = cvd.get_sigmas(S) + x = x*sigmas[0] + ++ print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps") ++ + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + +@@ -1017,6 +1063,7 @@ class UNet(DDPM): + e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) + denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + ++ if img_callback: yield from img_callback(x, i) + + d = to_d(x, sigmas[i], denoised) + ds.append(d) +@@ -1027,4 +1074,5 @@ class UNet(DDPM): + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) +- return x ++ ++ yield from img_callback(x, len(sigmas)-1) diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py index abc3098..7a32ffe 100644 --- a/optimizedSD/openaimodelSplit.py diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 9434414e..9fa88e5b 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -275,6 +275,7 @@ def do_mk_img(req: Request): opt_use_upscale = req.use_upscale opt_show_only_filtered = req.show_only_filtered_image opt_format = 'png' + opt_sampler_name = req.sampler print(req.to_string(), '\n device', device) @@ -399,12 +400,11 @@ def do_mk_img(req: Request): # run the handler try: if handler == _txt2img: - x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask) + x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name) else: - x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask) + x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask) - if req.stream_progress_updates: - yield from x_samples + yield from x_samples x_samples = partial_x_samples except UserInitiatedStop: @@ -443,7 +443,7 @@ def do_mk_img(req: Request): if return_orig_img: save_image(img, img_out_path) - save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale) + save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name) if return_orig_img: img_data = img_to_base64_str(img) @@ -496,10 +496,7 @@ def do_mk_img(req: Request): print('Task completed') - if req.stream_progress_updates: - yield json.dumps(res.json()) - else: - return res + yield json.dumps(res.json()) def save_image(img, img_out_path): try: @@ -507,8 +504,8 @@ def save_image(img, img_out_path): except: print('could not save the file', traceback.format_exc()) -def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale): - metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}" +def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name): + metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}" try: with open(meta_out_path, 'w') as f: @@ -516,7 +513,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps except: print('could not save the file', traceback.format_exc()) -def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask): +def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name): shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] if device != "cpu": @@ -536,17 +533,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, eta=opt_ddim_eta, x_T=start_code, img_callback=img_callback, - streaming_callbacks=streaming_callbacks, mask=mask, - sampler = 'plms', + sampler = sampler_name, ) - if streaming_callbacks: - yield from samples_ddim - else: - return samples_ddim + yield from samples_ddim -def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask): +def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask): # encode (scaled latent) z_enc = model.stochastic_encode( init_latent, @@ -565,16 +558,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o unconditional_guidance_scale=opt_scale, unconditional_conditioning=uc, img_callback=img_callback, - streaming_callbacks=streaming_callbacks, mask=mask, x_T=x_T, sampler = 'ddim' ) - if streaming_callbacks: - yield from samples_ddim - else: - return samples_ddim + yield from samples_ddim def move_fs_to_cpu(): if device != "cpu": diff --git a/ui/server.py b/ui/server.py index 2ca0f88d..ad3b1ae0 100644 --- a/ui/server.py +++ b/ui/server.py @@ -43,6 +43,7 @@ class ImageRequest(BaseModel): height: int = 512 seed: int = 42 prompt_strength: float = 0.8 + sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" # allow_nsfw: bool = False save_to_disk_path: str = None turbo: bool = True @@ -105,6 +106,7 @@ def image(req : ImageRequest): r.height = req.height r.seed = req.seed r.prompt_strength = req.prompt_strength + r.sampler = req.sampler # r.allow_nsfw = req.allow_nsfw r.turbo = req.turbo r.use_cpu = req.use_cpu