Show the progress percentage while generating images

This commit is contained in:
cmdr2 2022-09-14 16:52:03 +05:30
parent 3b47eb3b07
commit 9f48d5e5ff
5 changed files with 192 additions and 26 deletions

View File

@ -60,6 +60,9 @@
#outputMsg { #outputMsg {
font-size: small; font-size: small;
} }
#progressBar {
font-size: small;
}
#footer { #footer {
font-size: small; font-size: small;
padding-left: 10pt; padding-left: 10pt;
@ -397,6 +400,7 @@
<div id="preview-prompt">Type a prompt and press the "Make Image" button.<br/><br/>You can set an "Initial Image" if you want to guide the AI.<br/><br/>You can also add modifiers like "Realistic", "Pencil Sketch", "ArtStation" etc by browsing through the "Image Modifiers" section and selecting the desired modifiers.<br/><br/>Click "Advanced Settings" for additional settings like seed, image size, number of images to generate etc.<br/><br/>Enjoy! :)</div> <div id="preview-prompt">Type a prompt and press the "Make Image" button.<br/><br/>You can set an "Initial Image" if you want to guide the AI.<br/><br/>You can also add modifiers like "Realistic", "Pencil Sketch", "ArtStation" etc by browsing through the "Image Modifiers" section and selecting the desired modifiers.<br/><br/>Click "Advanced Settings" for additional settings like seed, image size, number of images to generate etc.<br/><br/>Enjoy! :)</div>
<div id="outputMsg"></div> <div id="outputMsg"></div>
<div id="progressBar"></div>
<div id="current-images" class="img-preview"> <div id="current-images" class="img-preview">
</div> </div>
</div> </div>
@ -483,6 +487,7 @@ let previewPrompt = document.querySelector('#preview-prompt')
let showConfigToggle = document.querySelector('#configToggleBtn') let showConfigToggle = document.querySelector('#configToggleBtn')
// let configBox = document.querySelector('#config') // let configBox = document.querySelector('#config')
let outputMsg = document.querySelector('#outputMsg') let outputMsg = document.querySelector('#outputMsg')
let progressBar = document.querySelector("#progressBar")
let soundToggle = document.querySelector('#sound_toggle') let soundToggle = document.querySelector('#sound_toggle')
@ -496,6 +501,7 @@ let serverStatus = 'offline'
let activeTags = [] let activeTags = []
let lastPromptUsed = '' let lastPromptUsed = ''
let taskStopped = true let taskStopped = true
let batchesDone = 0
function getLocalStorageItem(key, fallback) { function getLocalStorageItem(key, fallback) {
let item = localStorage.getItem(key) let item = localStorage.getItem(key)
@ -631,7 +637,7 @@ async function healthCheck() {
} }
// makes a single image. don't call this directly, use makeImage() instead // makes a single image. don't call this directly, use makeImage() instead
async function doMakeImage(reqBody) { async function doMakeImage(reqBody, batchCount) {
if (taskStopped) { if (taskStopped) {
return return
} }
@ -648,6 +654,41 @@ async function doMakeImage(reqBody) {
body: JSON.stringify(reqBody) body: JSON.stringify(reqBody)
}) })
let reader = res.body.getReader()
let textDecoder = new TextDecoder()
let finalJSON = ''
while (true) {
try {
const {value, done} = await reader.read()
if (done) {
break
}
let jsonStr = textDecoder.decode(value)
try {
let stepUpdate = JSON.parse(jsonStr)
if (stepUpdate.step !== undefined) {
let batchSize = parseInt(reqBody['num_inference_steps'])
let overallStepCount = stepUpdate.step + batchesDone * batchSize
let totalSteps = batchCount * batchSize
let percent = 100 * (overallStepCount / totalSteps)
percent = percent.toFixed(0)
outputMsg.innerHTML = `Batch ${batchesDone+1} of ${batchCount}`
progressBar.innerHTML = `Generating image: ${percent}%`
progressBar.style.display = 'block'
}
} catch (e) {
finalJSON += jsonStr
}
} catch (e) {
logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
res = undefined
throw e
}
}
if (res.status != 200) { if (res.status != 200) {
if (serverStatus === 'online') { if (serverStatus === 'online') {
logError('Stable Diffusion had an error: ' + await res.text() + '. This happens sometimes. Maybe modify the prompt or seed a little bit?', res) logError('Stable Diffusion had an error: ' + await res.text() + '. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
@ -655,8 +696,10 @@ async function doMakeImage(reqBody) {
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed.", res) logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed.", res)
} }
res = undefined res = undefined
progressBar.style.display = 'none'
} else { } else {
res = await res.json() res = JSON.parse(finalJSON)
progressBar.style.display = 'none'
if (res.status !== 'succeeded') { if (res.status !== 'succeeded') {
let msg = '' let msg = ''
@ -680,7 +723,9 @@ async function doMakeImage(reqBody) {
} }
} catch (e) { } catch (e) {
console.log('request error', e) console.log('request error', e)
logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
setStatus('request', 'error', 'error') setStatus('request', 'error', 'error')
progressBar.style.display = 'none'
} }
if (!res) { if (!res) {
@ -789,7 +834,7 @@ async function makeImage() {
let validation = validateInput() let validation = validateInput()
if (validation['isValid']) { if (validation['isValid']) {
outputMsg.innerHTML = 'Fetching..' outputMsg.innerHTML = 'Starting..'
} else { } else {
if (validation['error']) { if (validation['error']) {
logError(validation['error']) logError(validation['error'])
@ -806,6 +851,7 @@ async function makeImage() {
stopImageBtn.style.display = 'block' stopImageBtn.style.display = 'block'
taskStopped = false taskStopped = false
batchesDone = 0
let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value)) let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value))
let numOutputsTotal = parseInt(numOutputsTotalField.value) let numOutputsTotal = parseInt(numOutputsTotalField.value)
@ -831,7 +877,8 @@ async function makeImage() {
// allow_nsfw: allowNSFWField.checked, // allow_nsfw: allowNSFWField.checked,
turbo: turboField.checked, turbo: turboField.checked,
use_cpu: useCPUField.checked, use_cpu: useCPUField.checked,
use_full_precision: useFullPrecisionField.checked use_full_precision: useFullPrecisionField.checked,
stream_progress_updates: true
} }
if (IMAGE_REGEX.test(initImagePreview.src)) { if (IMAGE_REGEX.test(initImagePreview.src)) {
@ -867,7 +914,8 @@ async function makeImage() {
for (let i = 0; i < batchCount; i++) { for (let i = 0; i < batchCount; i++) {
reqBody['seed'] = seed + (i * batchSize) reqBody['seed'] = seed + (i * batchSize)
let success = await doMakeImage(reqBody) let success = await doMakeImage(reqBody, batchCount)
batchesDone++
if (success) { if (success) {
outputMsg.innerHTML = 'Processed batch ' + (i+1) + '/' + batchCount outputMsg.innerHTML = 'Processed batch ' + (i+1) + '/' + batchCount

View File

@ -21,6 +21,8 @@ class Request:
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
stream_progress_updates: bool = False
def json(self): def json(self):
return { return {
"prompt": self.prompt, "prompt": self.prompt,
@ -50,7 +52,9 @@ class Request:
use_full_precision: {self.use_full_precision} use_full_precision: {self.use_full_precision}
use_face_correction: {self.use_face_correction} use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale} use_upscale: {self.use_upscale}
show_only_filtered_image: {self.show_only_filtered_image}''' show_only_filtered_image: {self.show_only_filtered_image}
stream_progress_updates: {self.stream_progress_updates}'''
class Image: class Image:
data: str # base64 data: str # base64

View File

@ -1,34 +1,121 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf7901..1f99adc 100644 index dcf7901..4028a70 100644
--- a/optimizedSD/ddpm.py --- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py
@@ -528,7 +528,8 @@ class UNet(DDPM): @@ -485,6 +485,7 @@ class UNet(DDPM):
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
+ streaming_callbacks = False,
):
@@ -523,12 +524,15 @@ class UNet(DDPM):
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
+ streaming_callbacks=streaming_callbacks
)
elif sampler == "ddim": elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
- mask = mask,init_latent=x_T,use_original_steps=False) - mask = mask,init_latent=x_T,use_original_steps=False)
+ mask = mask,init_latent=x_T,use_original_steps=False, + mask = mask,init_latent=x_T,use_original_steps=False,
+ callback=callback, img_callback=img_callback) + callback=callback, img_callback=img_callback,
+ streaming_callbacks=streaming_callbacks)
# elif sampler == "euler": # elif sampler == "euler":
# cvd = CompVisDenoiser(self.alphas_cumprod) # cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -687,7 +688,8 @@ class UNet(DDPM): @@ -536,11 +540,15 @@ class UNet(DDPM):
# samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning,
# unconditional_guidance_scale=unconditional_guidance_scale)
+ if streaming_callbacks: # this line needs to be right after the sampling() call
+ yield from samples
+
if(self.turbo):
self.model1.to("cpu")
self.model2.to("cpu")
- return samples
+ if not streaming_callbacks:
+ return samples
@torch.no_grad()
def plms_sampling(self, cond,b, img,
@@ -548,7 +556,8 @@ class UNet(DDPM):
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ streaming_callbacks=False):
device = self.betas.device
timesteps = self.ddim_timesteps
@@ -580,10 +589,22 @@ class UNet(DDPM):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
- return img
+ if callback:
+ if streaming_callbacks:
+ yield from callback(i)
+ else:
+ callback(i)
+ if img_callback:
+ if streaming_callbacks:
+ yield from img_callback(pred_x0, i)
+ else:
+ img_callback(pred_x0, i)
+
+ if streaming_callbacks and img_callback:
+ yield from img_callback(img, len(iterator)-1)
+ else:
+ return img
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -687,7 +708,9 @@ class UNet(DDPM):
@torch.no_grad() @torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- mask = None,init_latent=None,use_original_steps=False): - mask = None,init_latent=None,use_original_steps=False):
+ mask = None,init_latent=None,use_original_steps=False, + mask = None,init_latent=None,use_original_steps=False,
+ callback=None, img_callback=None): + callback=None, img_callback=None,
+ streaming_callbacks=False):
timesteps = self.ddim_timesteps timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start] timesteps = timesteps[:t_start]
@@ -710,6 +712,9 @@ class UNet(DDPM): @@ -710,11 +733,25 @@ class UNet(DDPM):
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning) unconditional_conditioning=unconditional_conditioning)
+ +
+ if callback: callback(i) + if callback:
+ if img_callback: img_callback(x_dec, i) + if streaming_callbacks:
+ yield from callback(i)
+ else:
+ callback(i)
+ if img_callback:
+ if streaming_callbacks:
+ yield from img_callback(x_dec, i)
+ else:
+ img_callback(x_dec, i)
if mask is not None: if mask is not None:
return x0 * mask + (1. - mask) * x_dec - return x0 * mask + (1. - mask) * x_dec
+ x_dec = x0 * mask + (1. - mask) * x_dec
- return x_dec
+ if streaming_callbacks and img_callback:
+ yield from img_callback(x_dec, len(iterator)-1)
+ else:
+ return x_dec
@torch.no_grad()

View File

@ -1,3 +1,4 @@
import json
import os, re import os, re
import traceback import traceback
import torch import torch
@ -332,15 +333,23 @@ def mk_img(req: Request):
partial_x_samples = x_samples partial_x_samples = x_samples
if req.stream_progress_updates:
yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
if stop_processing: if stop_processing:
raise UserInitiatedStop("User requested that we stop processing") raise UserInitiatedStop("User requested that we stop processing")
# run the handler # run the handler
try: try:
if handler == _txt2img: if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback) x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
else: else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback) x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
if req.stream_progress_updates:
yield from x_samples
x_samples = partial_x_samples
except UserInitiatedStop: except UserInitiatedStop:
if partial_x_samples is None: if partial_x_samples is None:
continue continue
@ -421,6 +430,9 @@ def mk_img(req: Request):
del x_samples del x_samples
print("memory_final = ", torch.cuda.memory_allocated() / 1e6) print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
if req.stream_progress_updates:
yield json.dumps(res.json())
else:
return res return res
def save_image(img, img_out_path): def save_image(img, img_out_path):
@ -438,7 +450,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except: except:
print('could not save the file', traceback.format_exc()) print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback): def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu": if device != "cpu":
@ -458,12 +470,16 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
eta=opt_ddim_eta, eta=opt_ddim_eta,
x_T=start_code, x_T=start_code,
img_callback=img_callback, img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
sampler = 'plms', sampler = 'plms',
) )
if streaming_callbacks:
yield from samples_ddim
else:
return samples_ddim return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback): def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
# encode (scaled latent) # encode (scaled latent)
z_enc = model.stochastic_encode( z_enc = model.stochastic_encode(
init_latent, init_latent,
@ -480,9 +496,13 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_guidance_scale=opt_scale, unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
img_callback=img_callback, img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
sampler = 'ddim' sampler = 'ddim'
) )
if streaming_callbacks:
yield from samples_ddim
else:
return samples_ddim return samples_ddim
def gc(): def gc():

View File

@ -15,7 +15,7 @@ CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse from starlette.responses import FileResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
@ -50,6 +50,8 @@ class ImageRequest(BaseModel):
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
stream_progress_updates: bool = False
class SetAppConfigRequest(BaseModel): class SetAppConfigRequest(BaseModel):
update_branch: str = "main" update_branch: str = "main"
@ -106,9 +108,14 @@ def image(req : ImageRequest):
r.use_face_correction = req.use_face_correction r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image r.show_only_filtered_image = req.show_only_filtered_image
try: r.stream_progress_updates = req.stream_progress_updates
res: Response = runtime.mk_img(r)
try:
res = runtime.mk_img(r)
if r.stream_progress_updates:
return StreamingResponse(res, media_type='application/json')
else:
return res.json() return res.json()
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())