New samplers for txt2img: "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"

This commit is contained in:
cmdr2 2022-09-23 00:19:05 +05:30
parent b0c15bc430
commit 956b3d89db
5 changed files with 291 additions and 96 deletions
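In short: the sampler picked in the UI travels as a new `sampler` field on the render request, through `Request`/`ImageRequest`, and into the patched `optimizedSD/ddpm.py`, which dispatches to the matching sampling function. A minimal sketch of a request using the new field (the host, port, and route here are illustrative assumptions, not confirmed by this diff):

    import requests  # third-party HTTP client, used here only for illustration

    payload = {
        "prompt": "a photograph of an astronaut riding a horse",
        "width": 512,
        "height": 512,
        "seed": 30000,
        "sampler": "euler_a",  # one of: ddim, plms, heun, euler, euler_a, dpm2, dpm2_a, lms
    }
    res = requests.post("http://localhost:9000/image", json=payload)  # hypothetical endpoint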


@@ -304,7 +304,7 @@
 <div id="server-status-color">&nbsp;</div>
 <span id="server-status-msg">Stable Diffusion is starting..</span>
 </div>
-<h1>Stable Diffusion UI <small>v2.14 <span id="updateBranchLabel"></span></small></h1>
+<h1>Stable Diffusion UI <small>v2.15 <span id="updateBranchLabel"></span></small></h1>
 </div>
 <div id="editor-inputs">
 <div id="editor-inputs-prompt" class="row">
@@ -353,6 +353,18 @@
 <br/>
 <li><label for="seed">Seed:</label> <input id="seed" name="seed" size="10" value="30000"> <input id="random_seed" name="random_seed" type="checkbox" checked> <label for="random_seed">Random Image</label></li>
 <li><label for="num_outputs_total">Number of images to make:</label> <input id="num_outputs_total" name="num_outputs_total" value="1" size="4"> <label for="num_outputs_parallel">Generate in parallel:</label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="4"> (images at once)</li>
+<li id="samplerSelection"><label for="sampler">Sampler:</label>
+<select id="sampler" name="sampler">
+<option value="plms" selected>plms</option>
+<option value="ddim">ddim</option>
+<option value="heun">heun</option>
+<option value="euler">euler</option>
+<option value="euler_a">euler_a</option>
+<option value="dpm2">dpm2</option>
+<option value="dpm2_a">dpm2_a</option>
+<option value="lms">lms</option>
+</select>
+</li>
 <li><label for="width">Width:</label>
 <select id="width" name="width" value="512">
 <option value="128">128 (*)</option>
@@ -488,6 +500,8 @@ let diskPathField = document.querySelector('#diskPath')
 let useBetaChannelField = document.querySelector("#use_beta_channel")
 let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
 let promptStrengthField = document.querySelector('#prompt_strength')
+let samplerField = document.querySelector('#sampler')
+let samplerSelectionContainer = document.querySelector("#samplerSelection")
 let useFaceCorrectionField = document.querySelector("#use_face_correction")
 let useUpscalingField = document.querySelector("#use_upscale")
 let upscaleModelField = document.querySelector("#upscale_model")
@@ -1007,6 +1021,10 @@ async function makeImage() {
 if (maskSetting.checked) {
 reqBody['mask'] = inpaintingEditor.getImg()
 }
+reqBody['sampler'] = 'ddim'
+} else {
+reqBody['sampler'] = samplerField.value
 }
 if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
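Note the rule encoded above: whenever an init image is set (img2img/inpainting), the request forces 'ddim'; only txt2img honors the new dropdown. The same rule as a Python sketch (a hypothetical helper, not code from this commit):

    def pick_sampler(using_init_image, dropdown_choice):
        # img2img and inpainting paths only support DDIM in this codebase;
        # txt2img may use any of the eight samplers
        return 'ddim' if using_init_image else dropdown_choice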
@@ -1275,6 +1293,7 @@ function showInitImagePreview() {
 initImagePreviewContainer.style.display = 'block'
 inpaintingEditorContainer.style.display = 'none'
 promptStrengthContainer.style.display = 'block'
+samplerSelectionContainer.style.display = 'none'
 // maskSetting.checked = false
 })
@@ -1306,6 +1325,7 @@ initImageClearBtn.addEventListener('click', function() {
 // maskSetting.style.display = 'none'
 promptStrengthContainer.style.display = 'none'
+samplerSelectionContainer.style.display = 'block'
 })
 maskSetting.addEventListener('click', function() {


@@ -12,6 +12,7 @@ class Request:
 height: int = 512
 seed: int = 42
 prompt_strength: float = 0.8
+sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
 # allow_nsfw: bool = False
 precision: str = "autocast" # or "full"
 save_to_disk_path: str = None
@@ -36,6 +37,7 @@ class Request:
 "height": self.height,
 "seed": self.seed,
 "prompt_strength": self.prompt_strength,
+"sampler": self.sampler,
 "use_face_correction": self.use_face_correction,
 "use_upscale": self.use_upscale,
 }
@@ -46,6 +48,7 @@ class Request:
 prompt: {self.prompt}
 seed: {self.seed}
 num_inference_steps: {self.num_inference_steps}
+sampler: {self.sampler}
 guidance_scale: {self.guidance_scale}
 w: {self.width}
 h: {self.height}
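`sampler` defaults to `None` and nothing in this diff validates it; a guard like the following (hypothetical, not part of the commit) would reject unknown names early:

    VALID_SAMPLERS = {"ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"}

    def validate_sampler(name):
        # None is allowed; the backend picks the sampler in that case
        if name is not None and name not in VALID_SAMPLERS:
            raise ValueError(f"unknown sampler: {name}")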


@@ -1,5 +1,5 @@
 diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
-index b967b55..75ddd8b 100644
+index b967b55..10a7c32 100644
 --- a/optimizedSD/ddpm.py
 +++ b/optimizedSD/ddpm.py
 @@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config
@@ -11,122 +11,303 @@ index b967b55..75ddd8b 100644
 def disabled_train(self):
 """Overwrite model.train with this function to make sure train/eval mode
-@@ -485,6 +485,7 @@ class UNet(DDPM):
-log_every_t=100,
-unconditional_guidance_scale=1.,
-unconditional_conditioning=None,
-+ streaming_callbacks = False,
-):
-@@ -523,12 +524,15 @@ class UNet(DDPM):
-log_every_t=log_every_t,
-unconditional_guidance_scale=unconditional_guidance_scale,
-unconditional_conditioning=unconditional_conditioning,
-+ streaming_callbacks=streaming_callbacks
-)
+@@ -528,39 +528,46 @@ class UNet(DDPM):
 elif sampler == "ddim":
 samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
 unconditional_conditioning=unconditional_conditioning,
 - mask = mask,init_latent=x_T,use_original_steps=False)
 + mask = mask,init_latent=x_T,use_original_steps=False,
-+ callback=callback, img_callback=img_callback,
-+ streaming_callbacks=streaming_callbacks)
++ callback=callback, img_callback=img_callback)
 elif sampler == "euler":
 self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
-@@ -555,11 +559,15 @@ class UNet(DDPM):
-samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
-unconditional_guidance_scale=unconditional_guidance_scale)
-+ if streaming_callbacks: # this line needs to be right after the sampling() call
-+ yield from samples
+samples = self.euler_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+- unconditional_guidance_scale=unconditional_guidance_scale)
++ unconditional_guidance_scale=unconditional_guidance_scale,
++ img_callback=img_callback)
+elif sampler == "euler_a":
+self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
+samples = self.euler_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+- unconditional_guidance_scale=unconditional_guidance_scale)
++ unconditional_guidance_scale=unconditional_guidance_scale,
++ img_callback=img_callback)
+elif sampler == "dpm2":
+samples = self.dpm_2_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+- unconditional_guidance_scale=unconditional_guidance_scale)
++ unconditional_guidance_scale=unconditional_guidance_scale,
++ img_callback=img_callback)
+elif sampler == "heun":
+samples = self.heun_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+- unconditional_guidance_scale=unconditional_guidance_scale)
++ unconditional_guidance_scale=unconditional_guidance_scale,
++ img_callback=img_callback)
+elif sampler == "dpm2_a":
+samples = self.dpm_2_ancestral_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+- unconditional_guidance_scale=unconditional_guidance_scale)
++ unconditional_guidance_scale=unconditional_guidance_scale,
++ img_callback=img_callback)
+elif sampler == "lms":
+samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning,
+- unconditional_guidance_scale=unconditional_guidance_scale)
++ unconditional_guidance_scale=unconditional_guidance_scale,
++ img_callback=img_callback)
 +
++ yield from samples
 if(self.turbo):
 self.model1.to("cpu")
 self.model2.to("cpu")
 - return samples
-+ if not streaming_callbacks:
-+ return samples
+-
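    # Aside (illustrative Python sketch, not part of the patch): after this
    # change, sample() is unconditionally a generator. Every sampler branch
    # receives img_callback, and the trailing `yield from samples` forwards
    # whatever the chosen sampler yields. `run_sampler` is a stand-in name:
    def sample_sketch(run_sampler, img_callback):
        samples = run_sampler(img_callback=img_callback)  # always a generator now
        yield from samples  # progress updates flow straight through to the caller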
 @torch.no_grad()
 def plms_sampling(self, cond,b, img,
-@@ -567,7 +575,8 @@ class UNet(DDPM):
-callback=None, quantize_denoised=False,
-mask=None, x0=None, img_callback=None, log_every_t=100,
-temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-- unconditional_guidance_scale=1., unconditional_conditioning=None,):
-+ unconditional_guidance_scale=1., unconditional_conditioning=None,
-+ streaming_callbacks=False):
-device = self.betas.device
-timesteps = self.ddim_timesteps
-@@ -599,10 +608,21 @@ class UNet(DDPM):
+ddim_use_original_steps=False,
+@@ -599,10 +606,10 @@ class UNet(DDPM):
 old_eps.append(e_t)
 if len(old_eps) >= 4:
 old_eps.pop(0)
 - if callback: callback(i)
 - if img_callback: img_callback(pred_x0, i)
--
++ if callback: yield from callback(i)
++ if img_callback: yield from img_callback(pred_x0, i)
 - return img
-+ if callback:
-+ if streaming_callbacks:
-+ yield from callback(i)
-+ else:
-+ callback(i)
-+ if img_callback:
-+ if streaming_callbacks:
-+ yield from img_callback(pred_x0, i)
-+ else:
-+ img_callback(pred_x0, i)
-+
-+ if streaming_callbacks and img_callback:
-+ yield from img_callback(img, len(iterator)-1)
-+ else:
-+ return img
++ yield from img_callback(img, len(iterator)-1)
 @torch.no_grad()
 def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-@@ -706,7 +726,9 @@ class UNet(DDPM):
+@@ -706,7 +713,8 @@ class UNet(DDPM):
 @torch.no_grad()
 def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
 - mask = None,init_latent=None,use_original_steps=False):
 + mask = None,init_latent=None,use_original_steps=False,
-+ callback=None, img_callback=None,
-+ streaming_callbacks=False):
++ callback=None, img_callback=None):
 timesteps = self.ddim_timesteps
 timesteps = timesteps[:t_start]
-@@ -730,10 +752,24 @@ class UNet(DDPM):
+@@ -730,10 +738,13 @@ class UNet(DDPM):
 unconditional_guidance_scale=unconditional_guidance_scale,
 unconditional_conditioning=unconditional_conditioning)
-+ if callback:
-+ if streaming_callbacks:
-+ yield from callback(i)
-+ else:
-+ callback(i)
-+ if img_callback:
-+ if streaming_callbacks:
-+ yield from img_callback(x_dec, i)
-+ else:
-+ img_callback(x_dec, i)
++ if callback: yield from callback(i)
++ if img_callback: yield from img_callback(x_dec, i)
 +
 if mask is not None:
 - return x0 * mask + (1. - mask) * x_dec
 + x_dec = x0 * mask + (1. - mask) * x_dec
 - return x_dec
-+ if streaming_callbacks and img_callback:
-+ yield from img_callback(x_dec, len(iterator)-1)
-+ else:
-+ return x_dec
++ yield from img_callback(x_dec, len(iterator)-1)
 @torch.no_grad()
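    # Aside (illustrative Python sketch, not part of the patch): because
    # plms_sampling and ddim_sampling now do `yield from img_callback(...)`,
    # any callback passed in must itself be a generator. The payload shape
    # below is an assumption, not taken from this commit:
    import json

    def make_img_callback(total_steps):
        def img_callback(x, i):
            # x: current latent tensor, i: step index
            yield json.dumps({"step": i, "totalSteps": total_steps})
        return img_callback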
@@ -779,13 +790,16 @@ class UNet(DDPM):
@torch.no_grad()
- def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ def euler_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None,callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
+ img_callback=None):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
cvd = CompVisDenoiser(ac)
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
+ print(f"Running Euler Sampling with {len(sigmas) - 1} timesteps")
+
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
@@ -807,13 +821,18 @@ class UNet(DDPM):
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+
+ if img_callback: yield from img_callback(x, i)
+
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
- return x
+
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
- def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None):
+ def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None,
+ img_callback=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
@@ -822,6 +841,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
+ print(f"Running Euler Ancestral Sampling with {len(sigmas) - 1} timesteps")
+
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
@@ -837,17 +858,22 @@ class UNet(DDPM):
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+ if img_callback: yield from img_callback(x, i)
+
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + torch.randn_like(x) * sigma_up
- return x
+
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
- def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ def heun_sampling(self, ac, x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
+ img_callback=None):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
@@ -855,6 +881,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
+ print(f"Running Heun Sampling with {len(sigmas) - 1} timesteps")
+
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
@@ -876,6 +904,9 @@ class UNet(DDPM):
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+
+ if img_callback: yield from img_callback(x, i)
+
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
@@ -895,11 +926,13 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
- return x
+
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
- def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ def dpm_2_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.,
+ img_callback=None):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
@@ -907,6 +940,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
+ print(f"Running DPM2 Sampling with {len(sigmas) - 1} timesteps")
+
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
@@ -924,7 +959,7 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigma_hat, denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
@@ -945,11 +980,13 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
- return x
+
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
- def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None):
+ def dpm_2_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None,
+ img_callback=None):
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
extra_args = {} if extra_args is None else extra_args
@@ -957,6 +994,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
+ print(f"Running DPM2 Ancestral Sampling with {len(sigmas) - 1} timesteps")
+
s_in = x.new_ones([x.shape[0]]).half()
for i in trange(len(sigmas) - 1, disable=disable):
@@ -973,6 +1012,9 @@ class UNet(DDPM):
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+ if img_callback: yield from img_callback(x, i)
+
d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
@@ -993,11 +1035,13 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up
- return x
+
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
- def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4):
+ def lms_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, order=4,
+ img_callback=None):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
@@ -1005,6 +1049,8 @@ class UNet(DDPM):
sigmas = cvd.get_sigmas(S)
x = x*sigmas[0]
+ print(f"Running LMS Sampling with {len(sigmas) - 1} timesteps")
+
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
@@ -1017,6 +1063,7 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
@@ -1027,4 +1074,5 @@ class UNet(DDPM):
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
- return x
+
+ yield from img_callback(x, len(sigmas)-1)
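    # Aside (illustrative Python sketch, not part of the patch): the six
    # k-diffusion-style samplers above share one skeleton: scale the initial
    # noise by the first sigma, walk the sigma schedule, turn each denoised
    # prediction into a derivative via to_d(), and stream intermediates through
    # img_callback. `denoise` is a stand-in for the CFG-combined model call:
    def euler_sketch(denoise, x, sigmas, img_callback):
        x = x * sigmas[0]                            # as in euler_sampling above
        for i in range(len(sigmas) - 1):
            denoised = denoise(x, sigmas[i])         # model prediction
            d = (x - denoised) / sigmas[i]           # to_d(): ODE derivative
            yield from img_callback(x, i)            # intermediate preview
            x = x + d * (sigmas[i + 1] - sigmas[i])  # one Euler step
        yield from img_callback(x, len(sigmas) - 1)  # final image via callback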
 diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py
 index abc3098..7a32ffe 100644
 --- a/optimizedSD/openaimodelSplit.py


@@ -275,6 +275,7 @@ def do_mk_img(req: Request):
 opt_use_upscale = req.use_upscale
 opt_show_only_filtered = req.show_only_filtered_image
 opt_format = 'png'
+opt_sampler_name = req.sampler
 print(req.to_string(), '\n device', device)
@@ -399,12 +400,11 @@ def do_mk_img(req: Request):
 # run the handler
 try:
 if handler == _txt2img:
-x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
+x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
 else:
-x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
+x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
-if req.stream_progress_updates:
-yield from x_samples
+yield from x_samples
 x_samples = partial_x_samples
 except UserInitiatedStop:
@@ -443,7 +443,7 @@ def do_mk_img(req: Request):
 if return_orig_img:
 save_image(img, img_out_path)
-save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale)
+save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name)
 if return_orig_img:
 img_data = img_to_base64_str(img)
@@ -496,10 +496,7 @@ def do_mk_img(req: Request):
 print('Task completed')
-if req.stream_progress_updates:
-yield json.dumps(res.json())
-else:
-return res
+yield json.dumps(res.json())
 def save_image(img, img_out_path):
 try:
@@ -507,8 +504,8 @@ def save_image(img, img_out_path):
 except:
 print('could not save the file', traceback.format_exc())
-def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale):
-metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}"
+def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name):
+metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}"
 try:
 with open(meta_out_path, 'w') as f:
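    # Aside (illustrative, values invented): a metadata .txt written by the
    # function above now ends with a Sampler line, e.g.
    #
    #   a photograph of an astronaut riding a horse
    #   Width: 512
    #   Height: 512
    #   Seed: 30000
    #   Steps: 50
    #   Guidance Scale: 7.5
    #   Prompt Strength: 0.8
    #   Use Face Correction: False
    #   Use Upscaling: False
    #   Sampler: euler_a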
@@ -516,7 +513,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
 except:
 print('could not save the file', traceback.format_exc())
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
+def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
 shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
 if device != "cpu":
@@ -536,17 +533,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
 eta=opt_ddim_eta,
 x_T=start_code,
 img_callback=img_callback,
-streaming_callbacks=streaming_callbacks,
 mask=mask,
-sampler = 'plms',
+sampler = sampler_name,
 )
-if streaming_callbacks:
-yield from samples_ddim
-else:
-return samples_ddim
+yield from samples_ddim
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
+def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
 # encode (scaled latent)
 z_enc = model.stochastic_encode(
 init_latent,
@@ -565,16 +558,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
 unconditional_guidance_scale=opt_scale,
 unconditional_conditioning=uc,
 img_callback=img_callback,
-streaming_callbacks=streaming_callbacks,
 mask=mask,
 x_T=x_T,
 sampler = 'ddim'
 )
-if streaming_callbacks:
-yield from samples_ddim
-else:
-return samples_ddim
+yield from samples_ddim
 def move_fs_to_cpu():
 if device != "cpu":


@@ -43,6 +43,7 @@ class ImageRequest(BaseModel):
 height: int = 512
 seed: int = 42
 prompt_strength: float = 0.8
+sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
 # allow_nsfw: bool = False
 save_to_disk_path: str = None
 turbo: bool = True
@@ -105,6 +106,7 @@ def image(req : ImageRequest):
 r.height = req.height
 r.seed = req.seed
 r.prompt_strength = req.prompt_strength
+r.sampler = req.sampler
 # r.allow_nsfw = req.allow_nsfw
 r.turbo = req.turbo
 r.use_cpu = req.use_cpu