From 956b3d89db62a853bc7cdac250123ba8c2fc22f5 Mon Sep 17 00:00:00 2001
From: cmdr2 <secondary.cmdr2@gmail.com>
Date: Fri, 23 Sep 2022 00:19:05 +0530
Subject: [PATCH] New samplers for txt2img: "ddim", "plms", "heun", "euler",
 "euler_a", "dpm2", "dpm2_a", "lms"

---
 ui/index.html                      |  22 +-
 ui/sd_internal/__init__.py         |   3 +
 ui/sd_internal/ddim_callback.patch | 323 ++++++++++++++++++++++-------
 ui/sd_internal/runtime.py          |  37 ++--
 ui/server.py                       |   2 +
 5 files changed, 291 insertions(+), 96 deletions(-)

diff --git a/ui/index.html b/ui/index.html
index e8afeae1..dd1924b0 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -304,7 +304,7 @@
                     <div id="server-status-color">&nbsp;</div>
                     <span id="server-status-msg">Stable Diffusion is starting..</span>
                 </div>
-                <h1>Stable Diffusion UI <small>v2.14 <span id="updateBranchLabel"></span></small></h1>
+                <h1>Stable Diffusion UI <small>v2.15 <span id="updateBranchLabel"></span></small></h1>
             </div>
             <div id="editor-inputs">
                 <div id="editor-inputs-prompt" class="row">
@@ -353,6 +353,18 @@
                     <br/>
                     <li><label for="seed">Seed:</label> <input id="seed" name="seed" size="10" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label></li>
                     <li><label for="num_outputs_total">Number of images to make:</label> <input id="num_outputs_total" name="num_outputs_total" value="1" size="4"> <label for="num_outputs_parallel">Generate in parallel:</label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="4"> (images at once)</li>
+                    <li id="samplerSelection"><label for="sampler">Sampler:</label>
+                        <select id="sampler" name="sampler">
+                            <option value="plms" selected>plms</option>
+                            <option value="ddim">ddim</option>
+                            <option value="heun">heun</option>
+                            <option value="euler">euler</option>
+                            <option value="euler_a">euler_a</option>
+                            <option value="dpm2">dpm2</option>
+                            <option value="dpm2_a">dpm2_a</option>
+                            <option value="lms">lms</option>
+                        </select>
+                    </li>
                     <li><label for="width">Width:</label> 
                         <select id="width" name="width" value="512">
                             <option value="128">128 (*)</option>
@@ -488,6 +500,8 @@ let diskPathField = document.querySelector('#diskPath')
 let useBetaChannelField = document.querySelector("#use_beta_channel")
 let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
 let promptStrengthField = document.querySelector('#prompt_strength')
+let samplerField = document.querySelector('#sampler')
+let samplerSelectionContainer = document.querySelector("#samplerSelection")
 let useFaceCorrectionField = document.querySelector("#use_face_correction")
 let useUpscalingField = document.querySelector("#use_upscale")
 let upscaleModelField = document.querySelector("#upscale_model")
@@ -1007,6 +1021,10 @@ async function makeImage() {
         if (maskSetting.checked) {
             reqBody['mask'] = inpaintingEditor.getImg()
         }
+
+        reqBody['sampler'] = 'ddim'
+    } else {
+        reqBody['sampler'] = samplerField.value
     }
 
     if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
@@ -1275,6 +1293,7 @@ function showInitImagePreview() {
         initImagePreviewContainer.style.display = 'block'
         inpaintingEditorContainer.style.display = 'none'
         promptStrengthContainer.style.display = 'block'
+        samplerSelectionContainer.style.display = 'none'
         // maskSetting.checked = false
     })
 
@@ -1306,6 +1325,7 @@ initImageClearBtn.addEventListener('click', function() {
     // maskSetting.style.display = 'none'
 
     promptStrengthContainer.style.display = 'none'
+    samplerSelectionContainer.style.display = 'block'
 })
 
 maskSetting.addEventListener('click', function() {
diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index c0b6c6dc..e7630a1a 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -12,6 +12,7 @@ class Request:
     height: int = 512
     seed: int = 42
     prompt_strength: float = 0.8
+    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
     # allow_nsfw: bool = False
     precision: str = "autocast" # or "full"
     save_to_disk_path: str = None
@@ -36,6 +37,7 @@ class Request:
             "height": self.height,
             "seed": self.seed,
             "prompt_strength": self.prompt_strength,
+            "sampler": self.sampler,
             "use_face_correction": self.use_face_correction,
             "use_upscale": self.use_upscale,
         }
@@ -46,6 +48,7 @@ class Request:
     prompt: {self.prompt}
     seed: {self.seed}
     num_inference_steps: {self.num_inference_steps}
+    sampler: {self.sampler}
     guidance_scale: {self.guidance_scale}
     w: {self.width}
     h: {self.height}
diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch
index 7e40d614..9852aab8 100644
--- a/ui/sd_internal/ddim_callback.patch
+++ b/ui/sd_internal/ddim_callback.patch
@@ -1,5 +1,5 @@
 diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
-index b967b55..75ddd8b 100644
+index b967b55..10a7c32 100644
 --- a/optimizedSD/ddpm.py
 +++ b/optimizedSD/ddpm.py
 @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
@@ -11,122 +11,303 @@ index b967b55..75ddd8b 100644
  
  def disabled_train(self):
      """Overwrite model.train with this function to make sure train/eval mode
-@@ -485,6 +485,7 @@ class UNet(DDPM):
-                log_every_t=100,
-                unconditional_guidance_scale=1.,
-                unconditional_conditioning=None,
-+               streaming_callbacks = False,
-                ):
-         
- 
-@@ -523,12 +524,15 @@ class UNet(DDPM):
-                                         log_every_t=log_every_t,
-                                         unconditional_guidance_scale=unconditional_guidance_scale,
-                                         unconditional_conditioning=unconditional_conditioning,
-+                                        streaming_callbacks=streaming_callbacks
-                                         )
- 
+@@ -528,39 +528,46 @@ class UNet(DDPM):
          elif sampler == "ddim":
              samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                           unconditional_conditioning=unconditional_conditioning,
 -                                         mask = mask,init_latent=x_T,use_original_steps=False)
 +                                         mask = mask,init_latent=x_T,use_original_steps=False,
-+                                         callback=callback, img_callback=img_callback,
-+                                         streaming_callbacks=streaming_callbacks)
++                                         callback=callback, img_callback=img_callback)
  
          elif sampler == "euler":
              self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
-@@ -555,11 +559,15 @@ class UNet(DDPM):
-             samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
-                                         unconditional_guidance_scale=unconditional_guidance_scale)
+             samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+-                                        unconditional_guidance_scale=unconditional_guidance_scale)
++                                        unconditional_guidance_scale=unconditional_guidance_scale,
++                                        img_callback=img_callback)
+         elif sampler == "euler_a":
+             self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
+             samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+-                                        unconditional_guidance_scale=unconditional_guidance_scale)
++                                        unconditional_guidance_scale=unconditional_guidance_scale,
++                                        img_callback=img_callback)
  
-+        if streaming_callbacks: # this line needs to be right after the sampling() call
-+            yield from samples
+         elif sampler == "dpm2":
+             samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+-                                        unconditional_guidance_scale=unconditional_guidance_scale)
++                                        unconditional_guidance_scale=unconditional_guidance_scale,
++                                        img_callback=img_callback)
+         elif sampler == "heun":
+             samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+-                                        unconditional_guidance_scale=unconditional_guidance_scale)
++                                        unconditional_guidance_scale=unconditional_guidance_scale,
++                                        img_callback=img_callback)
+ 
+         elif sampler == "dpm2_a":
+             samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+-                                        unconditional_guidance_scale=unconditional_guidance_scale)
++                                        unconditional_guidance_scale=unconditional_guidance_scale,
++                                        img_callback=img_callback)
+ 
+ 
+         elif sampler == "lms":
+             samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+-                                        unconditional_guidance_scale=unconditional_guidance_scale)
++                                        unconditional_guidance_scale=unconditional_guidance_scale,
++                                        img_callback=img_callback)
 +
++        yield from samples
+ 
          if(self.turbo):
              self.model1.to("cpu")
              self.model2.to("cpu")
  
 -        return samples
-+        if not streaming_callbacks:
-+            return samples
- 
+-
      @torch.no_grad()
      def plms_sampling(self, cond,b, img,
-@@ -567,7 +575,8 @@ class UNet(DDPM):
-                       callback=None, quantize_denoised=False,
-                       mask=None, x0=None, img_callback=None, log_every_t=100,
-                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
--                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
-+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
-+                      streaming_callbacks=False):
-         
-         device = self.betas.device
-         timesteps = self.ddim_timesteps
-@@ -599,10 +608,21 @@ class UNet(DDPM):
+                       ddim_use_original_steps=False,
+@@ -599,10 +606,10 @@ class UNet(DDPM):
              old_eps.append(e_t)
              if len(old_eps) >= 4:
                  old_eps.pop(0)
 -            if callback: callback(i)
 -            if img_callback: img_callback(pred_x0, i)
--
++            if callback: yield from callback(i)
++            if img_callback: yield from img_callback(pred_x0, i)
+ 
 -        return img
-+            if callback:
-+                if streaming_callbacks:
-+                    yield from callback(i)
-+                else:
-+                    callback(i)
-+            if img_callback:
-+                if streaming_callbacks:
-+                    yield from img_callback(pred_x0, i)
-+                else:
-+                    img_callback(pred_x0, i)
-+
-+        if streaming_callbacks and img_callback:
-+            yield from img_callback(img, len(iterator)-1)
-+        else:
-+            return img
++        yield from img_callback(img, len(iterator)-1)
  
      @torch.no_grad()
      def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-@@ -706,7 +726,9 @@ class UNet(DDPM):
+@@ -706,7 +713,8 @@ class UNet(DDPM):
  
      @torch.no_grad()
      def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
 -               mask = None,init_latent=None,use_original_steps=False):
 +               mask = None,init_latent=None,use_original_steps=False,
-+               callback=None, img_callback=None,
-+               streaming_callbacks=False):
++               callback=None, img_callback=None):
  
          timesteps = self.ddim_timesteps
          timesteps = timesteps[:t_start]
-@@ -730,10 +752,24 @@ class UNet(DDPM):
+@@ -730,10 +738,13 @@ class UNet(DDPM):
                                            unconditional_guidance_scale=unconditional_guidance_scale,
                                            unconditional_conditioning=unconditional_conditioning)
          
-+            if callback:
-+                if streaming_callbacks:
-+                    yield from callback(i)
-+                else:
-+                    callback(i)
-+            if img_callback:
-+                if streaming_callbacks:
-+                    yield from img_callback(x_dec, i)
-+                else:
-+                    img_callback(x_dec, i)
++            if callback: yield from callback(i)
++            if img_callback: yield from img_callback(x_dec, i)
 +
          if mask is not None:
 -            return x0 * mask + (1. - mask) * x_dec
 +            x_dec = x0 * mask + (1. - mask) * x_dec
  
 -        return x_dec
-+        if streaming_callbacks and img_callback:
-+            yield from img_callback(x_dec, len(iterator)-1)
-+        else:
-+            return x_dec
++        yield from img_callback(x_dec, len(iterator)-1)
  
  
      @torch.no_grad()
+@@ -779,13 +790,16 @@ class UNet(DDPM):
+ 
+ 
+     @torch.no_grad()
+-    def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
++    def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
++                        img_callback=None):
+         """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
+         extra_args = {} if extra_args is None else extra_args
+         cvd = CompVisDenoiser(ac)
+         sigmas = cvd.get_sigmas(S)
+         x = x*sigmas[0]
+ 
++        print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps")
++
+         s_in = x.new_ones([x.shape[0]]).half()
+         for i in trange(len(sigmas) - 1, disable=disable):
+             gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+@@ -807,13 +821,18 @@ class UNet(DDPM):
+             d = to_d(x, sigma_hat, denoised)
+             if callback is not None:
+                 callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
++
++            if img_callback: yield from img_callback(x, i)
++
+             dt = sigmas[i + 1] - sigma_hat
+             # Euler method
+             x = x + d * dt
+-        return x
++
++        yield from img_callback(x, len(sigmas)-1)
+ 
+     @torch.no_grad()
+-    def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None):
++    def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None,
++                        img_callback=None):
+         """Ancestral sampling with Euler method steps."""
+         extra_args = {} if extra_args is None else extra_args
+ 
+@@ -822,6 +841,8 @@ class UNet(DDPM):
+         sigmas = cvd.get_sigmas(S)
+         x = x*sigmas[0]
+ 
++        print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps")
++
+         s_in = x.new_ones([x.shape[0]]).half()
+         for i in trange(len(sigmas) - 1, disable=disable):
+ 
+@@ -837,17 +858,22 @@ class UNet(DDPM):
+             sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+             if callback is not None:
+                 callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
++
++            if img_callback: yield from img_callback(x, i)
++
+             d = to_d(x, sigmas[i], denoised)
+             # Euler method
+             dt = sigma_down - sigmas[i]
+             x = x + d * dt
+             x = x + torch.randn_like(x) * sigma_up
+-        return x
++
++        yield from img_callback(x, len(sigmas)-1)
+ 
+ 
+ 
+     @torch.no_grad()
+-    def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
++    def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
++                        img_callback=None):
+         """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
+         extra_args = {} if extra_args is None else extra_args
+ 
+@@ -855,6 +881,8 @@ class UNet(DDPM):
+         sigmas = cvd.get_sigmas(S)
+         x = x*sigmas[0]
+ 
++        print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps")
++
+ 
+         s_in = x.new_ones([x.shape[0]]).half()
+         for i in trange(len(sigmas) - 1, disable=disable):
+@@ -876,6 +904,9 @@ class UNet(DDPM):
+             d = to_d(x, sigma_hat, denoised)
+             if callback is not None:
+                 callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
++
++            if img_callback: yield from img_callback(x, i)
++
+             dt = sigmas[i + 1] - sigma_hat
+             if sigmas[i + 1] == 0:
+                 # Euler method
+@@ -895,11 +926,13 @@ class UNet(DDPM):
+                 d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+                 d_prime = (d + d_2) / 2
+                 x = x + d_prime * dt
+-        return x
++
++        yield from img_callback(x, len(sigmas)-1)
+ 
+ 
+     @torch.no_grad()
+-    def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
++    def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
++                        img_callback=None):
+         """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
+         extra_args = {} if extra_args is None else extra_args
+ 
+@@ -907,6 +940,8 @@ class UNet(DDPM):
+         sigmas = cvd.get_sigmas(S)
+         x = x*sigmas[0]
+ 
++        print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps")
++
+         s_in = x.new_ones([x.shape[0]]).half()
+         for i in trange(len(sigmas) - 1, disable=disable):
+             gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+@@ -924,7 +959,7 @@ class UNet(DDPM):
+             e_t_uncond, e_t = (x_in  + eps * c_out).chunk(2)
+             denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+ 
+-
++            if img_callback: yield from img_callback(x, i)
+             
+             d = to_d(x, sigma_hat, denoised)
+             # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+@@ -945,11 +980,13 @@ class UNet(DDPM):
+ 
+             d_2 = to_d(x_2, sigma_mid, denoised_2)
+             x = x + d_2 * dt_2
+-        return x
++
++        yield from img_callback(x, len(sigmas)-1)
+ 
+ 
+     @torch.no_grad()
+-    def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None):
++    def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None,
++                        img_callback=None):
+         """Ancestral sampling with DPM-Solver inspired second-order steps."""
+         extra_args = {} if extra_args is None else extra_args
+ 
+@@ -957,6 +994,8 @@ class UNet(DDPM):
+         sigmas = cvd.get_sigmas(S)
+         x = x*sigmas[0]
+ 
++        print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps")
++
+         s_in = x.new_ones([x.shape[0]]).half()
+         for i in trange(len(sigmas) - 1, disable=disable):
+ 
+@@ -973,6 +1012,9 @@ class UNet(DDPM):
+             sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
+             if callback is not None:
+                 callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
++
++            if img_callback: yield from img_callback(x, i)
++
+             d = to_d(x, sigmas[i], denoised)
+             # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
+             sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
+@@ -993,11 +1035,13 @@ class UNet(DDPM):
+             d_2 = to_d(x_2, sigma_mid, denoised_2)
+             x = x + d_2 * dt_2
+             x = x + torch.randn_like(x) * sigma_up
+-        return x
++
++        yield from img_callback(x, len(sigmas)-1)
+ 
+ 
+     @torch.no_grad()
+-    def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4):
++    def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4,
++                        img_callback=None):
+         extra_args = {} if extra_args is None else extra_args
+         s_in = x.new_ones([x.shape[0]])
+ 
+@@ -1005,6 +1049,8 @@ class UNet(DDPM):
+         sigmas = cvd.get_sigmas(S)
+         x = x*sigmas[0]
+ 
++        print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps")
++
+         ds = []
+         for i in trange(len(sigmas) - 1, disable=disable):
+ 
+@@ -1017,6 +1063,7 @@ class UNet(DDPM):
+             e_t_uncond, e_t = (x_in  + eps * c_out).chunk(2)
+             denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+ 
++            if img_callback: yield from img_callback(x, i)
+ 
+             d = to_d(x, sigmas[i], denoised)
+             ds.append(d)
+@@ -1027,4 +1074,5 @@ class UNet(DDPM):
+             cur_order = min(i + 1, order)
+             coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
+             x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+-        return x
++
++        yield from img_callback(x, len(sigmas)-1)
 diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py
 index abc3098..7a32ffe 100644
 --- a/optimizedSD/openaimodelSplit.py
diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index 9434414e..9fa88e5b 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -275,6 +275,7 @@ def do_mk_img(req: Request):
     opt_use_upscale = req.use_upscale
     opt_show_only_filtered = req.show_only_filtered_image
     opt_format = 'png'
+    opt_sampler_name = req.sampler
 
     print(req.to_string(), '\n    device', device)
 
@@ -399,12 +400,11 @@ def do_mk_img(req: Request):
                     # run the handler
                     try:
                         if handler == _txt2img:
-                            x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
+                            x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
                         else:
-                            x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
+                            x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
 
-                        if req.stream_progress_updates:
-                            yield from x_samples
+                        yield from x_samples
 
                         x_samples = partial_x_samples
                     except UserInitiatedStop:
@@ -443,7 +443,7 @@ def do_mk_img(req: Request):
                             if return_orig_img:
                                 save_image(img, img_out_path)
 
-                            save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale)
+                            save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name)
 
                         if return_orig_img:
                             img_data = img_to_base64_str(img)
@@ -496,10 +496,7 @@ def do_mk_img(req: Request):
 
     print('Task completed')
 
-    if req.stream_progress_updates:
-        yield json.dumps(res.json())
-    else:
-        return res
+    yield json.dumps(res.json())
 
 def save_image(img, img_out_path):
     try:
@@ -507,8 +504,8 @@ def save_image(img, img_out_path):
     except:
         print('could not save the file', traceback.format_exc())
 
-def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale):
-    metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}"
+def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name):
+    metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}"
 
     try:
         with open(meta_out_path, 'w') as f:
@@ -516,7 +513,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
     except:
         print('could not save the file', traceback.format_exc())
 
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
+def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
 
     if device != "cpu":
@@ -536,17 +533,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
         eta=opt_ddim_eta,
         x_T=start_code,
         img_callback=img_callback,
-        streaming_callbacks=streaming_callbacks,
         mask=mask,
-        sampler = 'plms',
+        sampler = sampler_name,
     )
 
-    if streaming_callbacks:
-        yield from samples_ddim
-    else:
-        return samples_ddim
+    yield from samples_ddim
 
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
+def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
     # encode (scaled latent)
     z_enc = model.stochastic_encode(
         init_latent,
@@ -565,16 +558,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
         unconditional_guidance_scale=opt_scale,
         unconditional_conditioning=uc,
         img_callback=img_callback,
-        streaming_callbacks=streaming_callbacks,
         mask=mask,
         x_T=x_T,
         sampler = 'ddim'
     )
 
-    if streaming_callbacks:
-        yield from samples_ddim
-    else:
-        return samples_ddim
+    yield from samples_ddim
 
 def move_fs_to_cpu():
     if device != "cpu":
diff --git a/ui/server.py b/ui/server.py
index 2ca0f88d..ad3b1ae0 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -43,6 +43,7 @@ class ImageRequest(BaseModel):
     height: int = 512
     seed: int = 42
     prompt_strength: float = 0.8
+    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
     # allow_nsfw: bool = False
     save_to_disk_path: str = None
     turbo: bool = True
@@ -105,6 +106,7 @@ def image(req : ImageRequest):
     r.height = req.height
     r.seed = req.seed
     r.prompt_strength = req.prompt_strength
+    r.sampler = req.sampler
     # r.allow_nsfw = req.allow_nsfw
     r.turbo = req.turbo
     r.use_cpu = req.use_cpu