Show the progress percentage while generating images

This commit is contained in:
cmdr2 2022-09-14 16:52:03 +05:30
parent 3b47eb3b07
commit 9f48d5e5ff
5 changed files with 192 additions and 26 deletions

View File

@ -60,6 +60,9 @@
#outputMsg {
font-size: small;
}
#progressBar {
font-size: small;
}
#footer {
font-size: small;
padding-left: 10pt;
@ -397,6 +400,7 @@
<div id="preview-prompt">Type a prompt and press the "Make Image" button.<br/><br/>You can set an "Initial Image" if you want to guide the AI.<br/><br/>You can also add modifiers like "Realistic", "Pencil Sketch", "ArtStation" etc by browsing through the "Image Modifiers" section and selecting the desired modifiers.<br/><br/>Click "Advanced Settings" for additional settings like seed, image size, number of images to generate etc.<br/><br/>Enjoy! :)</div>
<div id="outputMsg"></div>
<div id="progressBar"></div>
<div id="current-images" class="img-preview">
</div>
</div>
@ -483,6 +487,7 @@ let previewPrompt = document.querySelector('#preview-prompt')
let showConfigToggle = document.querySelector('#configToggleBtn')
// let configBox = document.querySelector('#config')
let outputMsg = document.querySelector('#outputMsg')
let progressBar = document.querySelector("#progressBar")
let soundToggle = document.querySelector('#sound_toggle')
@ -496,6 +501,7 @@ let serverStatus = 'offline'
let activeTags = []
let lastPromptUsed = ''
let taskStopped = true
let batchesDone = 0
function getLocalStorageItem(key, fallback) {
let item = localStorage.getItem(key)
@ -631,7 +637,7 @@ async function healthCheck() {
}
// makes a single image. don't call this directly, use makeImage() instead
async function doMakeImage(reqBody) {
async function doMakeImage(reqBody, batchCount) {
if (taskStopped) {
return
}
@ -648,6 +654,41 @@ async function doMakeImage(reqBody) {
body: JSON.stringify(reqBody)
})
let reader = res.body.getReader()
let textDecoder = new TextDecoder()
let finalJSON = ''
while (true) {
try {
const {value, done} = await reader.read()
if (done) {
break
}
let jsonStr = textDecoder.decode(value)
try {
let stepUpdate = JSON.parse(jsonStr)
if (stepUpdate.step !== undefined) {
let batchSize = parseInt(reqBody['num_inference_steps'])
let overallStepCount = stepUpdate.step + batchesDone * batchSize
let totalSteps = batchCount * batchSize
let percent = 100 * (overallStepCount / totalSteps)
percent = percent.toFixed(0)
outputMsg.innerHTML = `Batch ${batchesDone+1} of ${batchCount}`
progressBar.innerHTML = `Generating image: ${percent}%`
progressBar.style.display = 'block'
}
} catch (e) {
finalJSON += jsonStr
}
} catch (e) {
logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
res = undefined
throw e
}
}
if (res.status != 200) {
if (serverStatus === 'online') {
logError('Stable Diffusion had an error: ' + await res.text() + '. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
@ -655,8 +696,10 @@ async function doMakeImage(reqBody) {
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed.", res)
}
res = undefined
progressBar.style.display = 'none'
} else {
res = await res.json()
res = JSON.parse(finalJSON)
progressBar.style.display = 'none'
if (res.status !== 'succeeded') {
let msg = ''
@ -680,7 +723,9 @@ async function doMakeImage(reqBody) {
}
} catch (e) {
console.log('request error', e)
logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
setStatus('request', 'error', 'error')
progressBar.style.display = 'none'
}
if (!res) {
@ -789,7 +834,7 @@ async function makeImage() {
let validation = validateInput()
if (validation['isValid']) {
outputMsg.innerHTML = 'Fetching..'
outputMsg.innerHTML = 'Starting..'
} else {
if (validation['error']) {
logError(validation['error'])
@ -806,6 +851,7 @@ async function makeImage() {
stopImageBtn.style.display = 'block'
taskStopped = false
batchesDone = 0
let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value))
let numOutputsTotal = parseInt(numOutputsTotalField.value)
@ -831,7 +877,8 @@ async function makeImage() {
// allow_nsfw: allowNSFWField.checked,
turbo: turboField.checked,
use_cpu: useCPUField.checked,
use_full_precision: useFullPrecisionField.checked
use_full_precision: useFullPrecisionField.checked,
stream_progress_updates: true
}
if (IMAGE_REGEX.test(initImagePreview.src)) {
@ -867,7 +914,8 @@ async function makeImage() {
for (let i = 0; i < batchCount; i++) {
reqBody['seed'] = seed + (i * batchSize)
let success = await doMakeImage(reqBody)
let success = await doMakeImage(reqBody, batchCount)
batchesDone++
if (success) {
outputMsg.innerHTML = 'Processed batch ' + (i+1) + '/' + batchCount

View File

@ -21,6 +21,8 @@ class Request:
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False
stream_progress_updates: bool = False
def json(self):
return {
"prompt": self.prompt,
@ -50,7 +52,9 @@ class Request:
use_full_precision: {self.use_full_precision}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
show_only_filtered_image: {self.show_only_filtered_image}'''
show_only_filtered_image: {self.show_only_filtered_image}
stream_progress_updates: {self.stream_progress_updates}'''
class Image:
data: str # base64

View File

@ -1,34 +1,121 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf7901..1f99adc 100644
index dcf7901..4028a70 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -528,7 +528,8 @@ class UNet(DDPM):
@@ -485,6 +485,7 @@ class UNet(DDPM):
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
+ streaming_callbacks = False,
):
@@ -523,12 +524,15 @@ class UNet(DDPM):
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
+ streaming_callbacks=streaming_callbacks
)
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
- mask = mask,init_latent=x_T,use_original_steps=False)
+ mask = mask,init_latent=x_T,use_original_steps=False,
+ callback=callback, img_callback=img_callback)
+ callback=callback, img_callback=img_callback,
+ streaming_callbacks=streaming_callbacks)
# elif sampler == "euler":
# cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -687,7 +688,8 @@ class UNet(DDPM):
@@ -536,11 +540,15 @@ class UNet(DDPM):
# samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning,
# unconditional_guidance_scale=unconditional_guidance_scale)
+ if streaming_callbacks: # this line needs to be right after the sampling() call
+ yield from samples
+
if(self.turbo):
self.model1.to("cpu")
self.model2.to("cpu")
- return samples
+ if not streaming_callbacks:
+ return samples
@torch.no_grad()
def plms_sampling(self, cond,b, img,
@@ -548,7 +556,8 @@ class UNet(DDPM):
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ streaming_callbacks=False):
device = self.betas.device
timesteps = self.ddim_timesteps
@@ -580,10 +589,22 @@ class UNet(DDPM):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
- return img
+ if callback:
+ if streaming_callbacks:
+ yield from callback(i)
+ else:
+ callback(i)
+ if img_callback:
+ if streaming_callbacks:
+ yield from img_callback(pred_x0, i)
+ else:
+ img_callback(pred_x0, i)
+
+ if streaming_callbacks and img_callback:
+ yield from img_callback(img, len(iterator)-1)
+ else:
+ return img
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -687,7 +708,9 @@ class UNet(DDPM):
@torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- mask = None,init_latent=None,use_original_steps=False):
+ mask = None,init_latent=None,use_original_steps=False,
+ callback=None, img_callback=None):
+ callback=None, img_callback=None,
+ streaming_callbacks=False):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
@@ -710,6 +712,9 @@ class UNet(DDPM):
@@ -710,11 +733,25 @@ class UNet(DDPM):
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
+
+ if callback: callback(i)
+ if img_callback: img_callback(x_dec, i)
+ if callback:
+ if streaming_callbacks:
+ yield from callback(i)
+ else:
+ callback(i)
+ if img_callback:
+ if streaming_callbacks:
+ yield from img_callback(x_dec, i)
+ else:
+ img_callback(x_dec, i)
if mask is not None:
return x0 * mask + (1. - mask) * x_dec
- return x0 * mask + (1. - mask) * x_dec
+ x_dec = x0 * mask + (1. - mask) * x_dec
- return x_dec
+ if streaming_callbacks and img_callback:
+ yield from img_callback(x_dec, len(iterator)-1)
+ else:
+ return x_dec
@torch.no_grad()

View File

@ -1,3 +1,4 @@
import json
import os, re
import traceback
import torch
@ -332,15 +333,23 @@ def mk_img(req: Request):
partial_x_samples = x_samples
if req.stream_progress_updates:
yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
# run the handler
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
if req.stream_progress_updates:
yield from x_samples
x_samples = partial_x_samples
except UserInitiatedStop:
if partial_x_samples is None:
continue
@ -421,6 +430,9 @@ def mk_img(req: Request):
del x_samples
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
if req.stream_progress_updates:
yield json.dumps(res.json())
else:
return res
def save_image(img, img_out_path):
@ -438,7 +450,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@ -458,12 +470,16 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
sampler = 'plms',
)
if streaming_callbacks:
yield from samples_ddim
else:
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@ -480,9 +496,13 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
sampler = 'ddim'
)
if streaming_callbacks:
yield from samples_ddim
else:
return samples_ddim
def gc():

View File

@ -15,7 +15,7 @@ CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
from fastapi import FastAPI, HTTPException
from starlette.responses import FileResponse
from starlette.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import logging
@ -50,6 +50,8 @@ class ImageRequest(BaseModel):
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
show_only_filtered_image: bool = False
stream_progress_updates: bool = False
class SetAppConfigRequest(BaseModel):
update_branch: str = "main"
@ -106,9 +108,14 @@ def image(req : ImageRequest):
r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image
try:
res: Response = runtime.mk_img(r)
r.stream_progress_updates = req.stream_progress_updates
try:
res = runtime.mk_img(r)
if r.stream_progress_updates:
return StreamingResponse(res, media_type='application/json')
else:
return res.json()
except Exception as e:
print(traceback.format_exc())