mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-08 01:14:03 +01:00
Show the progress percentage while generating images
This commit is contained in:
parent
3b47eb3b07
commit
9f48d5e5ff
@ -60,6 +60,9 @@
|
||||
#outputMsg {
|
||||
font-size: small;
|
||||
}
|
||||
#progressBar {
|
||||
font-size: small;
|
||||
}
|
||||
#footer {
|
||||
font-size: small;
|
||||
padding-left: 10pt;
|
||||
@ -397,6 +400,7 @@
|
||||
<div id="preview-prompt">Type a prompt and press the "Make Image" button.<br/><br/>You can set an "Initial Image" if you want to guide the AI.<br/><br/>You can also add modifiers like "Realistic", "Pencil Sketch", "ArtStation" etc by browsing through the "Image Modifiers" section and selecting the desired modifiers.<br/><br/>Click "Advanced Settings" for additional settings like seed, image size, number of images to generate etc.<br/><br/>Enjoy! :)</div>
|
||||
|
||||
<div id="outputMsg"></div>
|
||||
<div id="progressBar"></div>
|
||||
<div id="current-images" class="img-preview">
|
||||
</div>
|
||||
</div>
|
||||
@ -483,6 +487,7 @@ let previewPrompt = document.querySelector('#preview-prompt')
|
||||
let showConfigToggle = document.querySelector('#configToggleBtn')
|
||||
// let configBox = document.querySelector('#config')
|
||||
let outputMsg = document.querySelector('#outputMsg')
|
||||
let progressBar = document.querySelector("#progressBar")
|
||||
|
||||
let soundToggle = document.querySelector('#sound_toggle')
|
||||
|
||||
@ -496,6 +501,7 @@ let serverStatus = 'offline'
|
||||
let activeTags = []
|
||||
let lastPromptUsed = ''
|
||||
let taskStopped = true
|
||||
let batchesDone = 0
|
||||
|
||||
function getLocalStorageItem(key, fallback) {
|
||||
let item = localStorage.getItem(key)
|
||||
@ -631,7 +637,7 @@ async function healthCheck() {
|
||||
}
|
||||
|
||||
// makes a single image. don't call this directly, use makeImage() instead
|
||||
async function doMakeImage(reqBody) {
|
||||
async function doMakeImage(reqBody, batchCount) {
|
||||
if (taskStopped) {
|
||||
return
|
||||
}
|
||||
@ -648,6 +654,41 @@ async function doMakeImage(reqBody) {
|
||||
body: JSON.stringify(reqBody)
|
||||
})
|
||||
|
||||
let reader = res.body.getReader()
|
||||
let textDecoder = new TextDecoder()
|
||||
let finalJSON = ''
|
||||
while (true) {
|
||||
try {
|
||||
const {value, done} = await reader.read()
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
|
||||
let jsonStr = textDecoder.decode(value)
|
||||
|
||||
try {
|
||||
let stepUpdate = JSON.parse(jsonStr)
|
||||
|
||||
if (stepUpdate.step !== undefined) {
|
||||
let batchSize = parseInt(reqBody['num_inference_steps'])
|
||||
let overallStepCount = stepUpdate.step + batchesDone * batchSize
|
||||
let totalSteps = batchCount * batchSize
|
||||
let percent = 100 * (overallStepCount / totalSteps)
|
||||
percent = percent.toFixed(0)
|
||||
outputMsg.innerHTML = `Batch ${batchesDone+1} of ${batchCount}`
|
||||
progressBar.innerHTML = `Generating image: ${percent}%`
|
||||
progressBar.style.display = 'block'
|
||||
}
|
||||
} catch (e) {
|
||||
finalJSON += jsonStr
|
||||
}
|
||||
} catch (e) {
|
||||
logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
|
||||
res = undefined
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
if (res.status != 200) {
|
||||
if (serverStatus === 'online') {
|
||||
logError('Stable Diffusion had an error: ' + await res.text() + '. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
|
||||
@ -655,8 +696,10 @@ async function doMakeImage(reqBody) {
|
||||
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed.", res)
|
||||
}
|
||||
res = undefined
|
||||
progressBar.style.display = 'none'
|
||||
} else {
|
||||
res = await res.json()
|
||||
res = JSON.parse(finalJSON)
|
||||
progressBar.style.display = 'none'
|
||||
|
||||
if (res.status !== 'succeeded') {
|
||||
let msg = ''
|
||||
@ -680,7 +723,9 @@ async function doMakeImage(reqBody) {
|
||||
}
|
||||
} catch (e) {
|
||||
console.log('request error', e)
|
||||
logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
|
||||
setStatus('request', 'error', 'error')
|
||||
progressBar.style.display = 'none'
|
||||
}
|
||||
|
||||
if (!res) {
|
||||
@ -789,7 +834,7 @@ async function makeImage() {
|
||||
|
||||
let validation = validateInput()
|
||||
if (validation['isValid']) {
|
||||
outputMsg.innerHTML = 'Fetching..'
|
||||
outputMsg.innerHTML = 'Starting..'
|
||||
} else {
|
||||
if (validation['error']) {
|
||||
logError(validation['error'])
|
||||
@ -806,6 +851,7 @@ async function makeImage() {
|
||||
stopImageBtn.style.display = 'block'
|
||||
|
||||
taskStopped = false
|
||||
batchesDone = 0
|
||||
|
||||
let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value))
|
||||
let numOutputsTotal = parseInt(numOutputsTotalField.value)
|
||||
@ -831,7 +877,8 @@ async function makeImage() {
|
||||
// allow_nsfw: allowNSFWField.checked,
|
||||
turbo: turboField.checked,
|
||||
use_cpu: useCPUField.checked,
|
||||
use_full_precision: useFullPrecisionField.checked
|
||||
use_full_precision: useFullPrecisionField.checked,
|
||||
stream_progress_updates: true
|
||||
}
|
||||
|
||||
if (IMAGE_REGEX.test(initImagePreview.src)) {
|
||||
@ -867,7 +914,8 @@ async function makeImage() {
|
||||
for (let i = 0; i < batchCount; i++) {
|
||||
reqBody['seed'] = seed + (i * batchSize)
|
||||
|
||||
let success = await doMakeImage(reqBody)
|
||||
let success = await doMakeImage(reqBody, batchCount)
|
||||
batchesDone++
|
||||
|
||||
if (success) {
|
||||
outputMsg.innerHTML = 'Processed batch ' + (i+1) + '/' + batchCount
|
||||
|
@ -21,6 +21,8 @@ class Request:
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
show_only_filtered_image: bool = False
|
||||
|
||||
stream_progress_updates: bool = False
|
||||
|
||||
def json(self):
|
||||
return {
|
||||
"prompt": self.prompt,
|
||||
@ -50,7 +52,9 @@ class Request:
|
||||
use_full_precision: {self.use_full_precision}
|
||||
use_face_correction: {self.use_face_correction}
|
||||
use_upscale: {self.use_upscale}
|
||||
show_only_filtered_image: {self.show_only_filtered_image}'''
|
||||
show_only_filtered_image: {self.show_only_filtered_image}
|
||||
|
||||
stream_progress_updates: {self.stream_progress_updates}'''
|
||||
|
||||
class Image:
|
||||
data: str # base64
|
||||
|
@ -1,34 +1,121 @@
|
||||
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
|
||||
index dcf7901..1f99adc 100644
|
||||
index dcf7901..4028a70 100644
|
||||
--- a/optimizedSD/ddpm.py
|
||||
+++ b/optimizedSD/ddpm.py
|
||||
@@ -528,7 +528,8 @@ class UNet(DDPM):
|
||||
@@ -485,6 +485,7 @@ class UNet(DDPM):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
+ streaming_callbacks = False,
|
||||
):
|
||||
|
||||
|
||||
@@ -523,12 +524,15 @@ class UNet(DDPM):
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
+ streaming_callbacks=streaming_callbacks
|
||||
)
|
||||
|
||||
elif sampler == "ddim":
|
||||
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
- mask = mask,init_latent=x_T,use_original_steps=False)
|
||||
+ mask = mask,init_latent=x_T,use_original_steps=False,
|
||||
+ callback=callback, img_callback=img_callback)
|
||||
+ callback=callback, img_callback=img_callback,
|
||||
+ streaming_callbacks=streaming_callbacks)
|
||||
|
||||
# elif sampler == "euler":
|
||||
# cvd = CompVisDenoiser(self.alphas_cumprod)
|
||||
@@ -687,7 +688,8 @@ class UNet(DDPM):
|
||||
@@ -536,11 +540,15 @@ class UNet(DDPM):
|
||||
# samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning,
|
||||
# unconditional_guidance_scale=unconditional_guidance_scale)
|
||||
|
||||
+ if streaming_callbacks: # this line needs to be right after the sampling() call
|
||||
+ yield from samples
|
||||
+
|
||||
if(self.turbo):
|
||||
self.model1.to("cpu")
|
||||
self.model2.to("cpu")
|
||||
|
||||
- return samples
|
||||
+ if not streaming_callbacks:
|
||||
+ return samples
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond,b, img,
|
||||
@@ -548,7 +556,8 @@ class UNet(DDPM):
|
||||
callback=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
+ streaming_callbacks=False):
|
||||
|
||||
device = self.betas.device
|
||||
timesteps = self.ddim_timesteps
|
||||
@@ -580,10 +589,22 @@ class UNet(DDPM):
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
- if callback: callback(i)
|
||||
- if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
- return img
|
||||
+ if callback:
|
||||
+ if streaming_callbacks:
|
||||
+ yield from callback(i)
|
||||
+ else:
|
||||
+ callback(i)
|
||||
+ if img_callback:
|
||||
+ if streaming_callbacks:
|
||||
+ yield from img_callback(pred_x0, i)
|
||||
+ else:
|
||||
+ img_callback(pred_x0, i)
|
||||
+
|
||||
+ if streaming_callbacks and img_callback:
|
||||
+ yield from img_callback(img, len(iterator)-1)
|
||||
+ else:
|
||||
+ return img
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
@@ -687,7 +708,9 @@ class UNet(DDPM):
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
- mask = None,init_latent=None,use_original_steps=False):
|
||||
+ mask = None,init_latent=None,use_original_steps=False,
|
||||
+ callback=None, img_callback=None):
|
||||
+ callback=None, img_callback=None,
|
||||
+ streaming_callbacks=False):
|
||||
|
||||
timesteps = self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
@@ -710,6 +712,9 @@ class UNet(DDPM):
|
||||
@@ -710,11 +733,25 @@ class UNet(DDPM):
|
||||
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
+
|
||||
+ if callback: callback(i)
|
||||
+ if img_callback: img_callback(x_dec, i)
|
||||
+ if callback:
|
||||
+ if streaming_callbacks:
|
||||
+ yield from callback(i)
|
||||
+ else:
|
||||
+ callback(i)
|
||||
+ if img_callback:
|
||||
+ if streaming_callbacks:
|
||||
+ yield from img_callback(x_dec, i)
|
||||
+ else:
|
||||
+ img_callback(x_dec, i)
|
||||
|
||||
if mask is not None:
|
||||
return x0 * mask + (1. - mask) * x_dec
|
||||
- return x0 * mask + (1. - mask) * x_dec
|
||||
+ x_dec = x0 * mask + (1. - mask) * x_dec
|
||||
|
||||
- return x_dec
|
||||
+ if streaming_callbacks and img_callback:
|
||||
+ yield from img_callback(x_dec, len(iterator)-1)
|
||||
+ else:
|
||||
+ return x_dec
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os, re
|
||||
import traceback
|
||||
import torch
|
||||
@ -332,15 +333,23 @@ def mk_img(req: Request):
|
||||
|
||||
partial_x_samples = x_samples
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
|
||||
|
||||
if stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
# run the handler
|
||||
try:
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield from x_samples
|
||||
|
||||
x_samples = partial_x_samples
|
||||
except UserInitiatedStop:
|
||||
if partial_x_samples is None:
|
||||
continue
|
||||
@ -421,7 +430,10 @@ def mk_img(req: Request):
|
||||
del x_samples
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
|
||||
return res
|
||||
if req.stream_progress_updates:
|
||||
yield json.dumps(res.json())
|
||||
else:
|
||||
return res
|
||||
|
||||
def save_image(img, img_out_path):
|
||||
try:
|
||||
@ -438,7 +450,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
|
||||
except:
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
@ -458,12 +470,16 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
eta=opt_ddim_eta,
|
||||
x_T=start_code,
|
||||
img_callback=img_callback,
|
||||
streaming_callbacks=streaming_callbacks,
|
||||
sampler = 'plms',
|
||||
)
|
||||
|
||||
return samples_ddim
|
||||
if streaming_callbacks:
|
||||
yield from samples_ddim
|
||||
else:
|
||||
return samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
@ -480,10 +496,14 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
unconditional_guidance_scale=opt_scale,
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=img_callback,
|
||||
streaming_callbacks=streaming_callbacks,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
return samples_ddim
|
||||
if streaming_callbacks:
|
||||
yield from samples_ddim
|
||||
else:
|
||||
return samples_ddim
|
||||
|
||||
def gc():
|
||||
if device == 'cpu':
|
||||
|
15
ui/server.py
15
ui/server.py
@ -15,7 +15,7 @@ CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from starlette.responses import FileResponse
|
||||
from starlette.responses import FileResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
@ -50,6 +50,8 @@ class ImageRequest(BaseModel):
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
show_only_filtered_image: bool = False
|
||||
|
||||
stream_progress_updates: bool = False
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = "main"
|
||||
|
||||
@ -106,10 +108,15 @@ def image(req : ImageRequest):
|
||||
r.use_face_correction = req.use_face_correction
|
||||
r.show_only_filtered_image = req.show_only_filtered_image
|
||||
|
||||
try:
|
||||
res: Response = runtime.mk_img(r)
|
||||
r.stream_progress_updates = req.stream_progress_updates
|
||||
|
||||
return res.json()
|
||||
try:
|
||||
res = runtime.mk_img(r)
|
||||
|
||||
if r.stream_progress_updates:
|
||||
return StreamingResponse(res, media_type='application/json')
|
||||
else:
|
||||
return res.json()
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
Loading…
Reference in New Issue
Block a user