Live preview of images

This commit is contained in:
cmdr2 2022-09-14 22:29:42 +05:30
parent 1d88a5b42e
commit 27071cfa29
4 changed files with 102 additions and 18 deletions

View File

@ -285,7 +285,7 @@
<div id="server-status-color">&nbsp;</div>
<span id="server-status-msg">Stable Diffusion is starting..</span>
</div>
<h1>Stable Diffusion UI <small>v2.1 <span id="updateBranchLabel"></span></small></h1>
<h1>Stable Diffusion UI <small>v2.11 <span id="updateBranchLabel"></span></small></h1>
</div>
<div id="editor-inputs">
<div id="editor-inputs-prompt" class="row">
@ -316,6 +316,7 @@
<div id="editor-settings" class="panel-box">
<h4 class="collapsible">Advanced Settings</h4>
<ul id="editor-settings-entries" class="collapsible-content">
<li><input id="stream_image_progress" name="stream_image_progress" type="checkbox" checked> <label for="stream_image_progress">Show a live preview of the image (disable this for faster image generation)</label></li>
<li><input id="use_face_correction" name="use_face_correction" type="checkbox" checked> <label for="use_face_correction">Fix incorrect faces and eyes (uses GFPGAN)</label></li>
<li>
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale the image to 4x resolution using </label>
@ -432,11 +433,14 @@ const MODIFIERS_PANEL_OPEN_KEY = "modifiersPanelOpen"
const USE_FACE_CORRECTION_KEY = "useFaceCorrection"
const USE_UPSCALING_KEY = "useUpscaling"
const SHOW_ONLY_FILTERED_IMAGE_KEY = "showOnlyFilteredImage"
const STREAM_IMAGE_PROGRESS_KEY = "streamImageProgress"
const HEALTH_PING_INTERVAL = 5 // seconds
const MAX_INIT_IMAGE_DIMENSION = 768
const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64')
let sessionId = new Date().getTime()
let promptField = document.querySelector('#prompt')
let numOutputsTotalField = document.querySelector('#num_outputs_total')
let numOutputsParallelField = document.querySelector('#num_outputs_parallel')
@ -465,6 +469,7 @@ let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model")
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel")
let streamImageProgressField = document.querySelector("#stream_image_progress")
let makeImageBtn = document.querySelector('#makeImage')
let stopImageBtn = document.querySelector('#stopImage')
@ -577,6 +582,10 @@ function isModifiersPanelOpenEnabled() {
return getLocalStorageBoolItem(MODIFIERS_PANEL_OPEN_KEY, false)
}
function isStreamImageProgressEnabled() {
return getLocalStorageBoolItem(STREAM_IMAGE_PROGRESS_KEY, true)
}
function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') {
return;
@ -636,6 +645,20 @@ async function healthCheck() {
}
}
function makeImageElement(width, height) {
let imgItem = document.createElement('div')
imgItem.className = 'imgItem'
let img = document.createElement('img')
img.width = parseInt(width)
img.height = parseInt(height)
imgItem.appendChild(img)
imagesContainer.appendChild(imgItem)
return imgItem
}
// makes a single image. don't call this directly, use makeImage() instead
async function doMakeImage(reqBody, batchCount) {
if (taskStopped) {
@ -644,6 +667,17 @@ async function doMakeImage(reqBody, batchCount) {
let res = ''
let seed = reqBody['seed']
let numOutputs = parseInt(reqBody['num_outputs'])
let images = []
function makeImageContainers() {
if (images.length === 0) {
for (let i = 0; i < numOutputs; i++) {
images.push(makeImageElement(reqBody.width, reqBody.height))
}
}
}
try {
res = await fetch('/image', {
@ -693,6 +727,17 @@ async function doMakeImage(reqBody, batchCount) {
progressBar.innerHTML += `<br>Time remaining (approx): ${millisecondsToStr(timeRemaining)}`
}
progressBar.style.display = 'block'
if (stepUpdate.output !== undefined) {
makeImageContainers()
for (idx in stepUpdate.output) {
let imgItem = images[idx]
let img = imgItem.firstChild
let tmpImageData = stepUpdate.output[idx]
img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
}
}
}
} catch (e) {
finalJSON += jsonStr
@ -751,6 +796,8 @@ async function doMakeImage(reqBody, batchCount) {
lastPromptUsed = reqBody['prompt']
makeImageContainers()
for (let idx in res.output) {
let imgBody = ''
let seed = 0
@ -765,12 +812,9 @@ async function doMakeImage(reqBody, batchCount) {
continue
}
let imgItem = document.createElement('div')
imgItem.className = 'imgItem'
let imgItem = images[idx]
let img = imgItem.firstChild
let img = document.createElement('img')
img.width = parseInt(reqBody.width)
img.height = parseInt(reqBody.height)
img.src = imgBody
let imgItemInfo = document.createElement('span')
@ -788,12 +832,10 @@ async function doMakeImage(reqBody, batchCount) {
imgSaveBtn.className = 'imgSaveBtn'
imgSaveBtn.innerHTML = 'Download'
imgItem.appendChild(img)
imgItem.appendChild(imgItemInfo)
imgItemInfo.appendChild(imgSeedLabel)
imgItemInfo.appendChild(imgUseBtn)
imgItemInfo.appendChild(imgSaveBtn)
imagesContainer.appendChild(imgItem)
imgUseBtn.addEventListener('click', function() {
initImageSelector.value = null
@ -876,6 +918,8 @@ async function makeImage() {
let batchCount = Math.ceil(numOutputsTotal / numOutputsParallel)
let batchSize = numOutputsParallel
let streamImageProgress = (numOutputsTotal > 50 ? false : streamImageProgressField.checked)
let prompt = promptField.value
if (activeTags.length > 0) {
let promptTags = activeTags.join(", ")
@ -885,6 +929,7 @@ async function makeImage() {
previewPrompt.innerHTML = prompt
let reqBody = {
session_id: sessionId,
prompt: prompt,
num_outputs: batchSize,
num_inference_steps: numInferenceStepsField.value,
@ -895,7 +940,8 @@ async function makeImage() {
turbo: turboField.checked,
use_cpu: useCPUField.checked,
use_full_precision: useFullPrecisionField.checked,
stream_progress_updates: true
stream_progress_updates: true,
stream_image_progress: streamImageProgress
}
if (IMAGE_REGEX.test(initImagePreview.src)) {
@ -1036,6 +1082,9 @@ useFullPrecisionField.checked = isUseFullPrecisionEnabled()
turboField.addEventListener('click', handleBoolSettingChange(USE_TURBO_MODE_KEY))
turboField.checked = isUseTurboModeEnabled()
streamImageProgressField.addEventListener('click', handleBoolSettingChange(STREAM_IMAGE_PROGRESS_KEY))
streamImageProgressField.checked = isStreamImageProgressEnabled()
diskPathField.addEventListener('change', handleStringSettingChange(DISK_PATH_KEY))
saveToDiskField.addEventListener('click', function(e) {

View File

@ -1,6 +1,7 @@
import json
class Request:
session_id: str = "session"
prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
@ -22,9 +23,11 @@ class Request:
show_only_filtered_image: bool = False
stream_progress_updates: bool = False
stream_image_progress: bool = False
def json(self):
return {
"session_id": self.session_id,
"prompt": self.prompt,
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
@ -39,6 +42,7 @@ class Request:
def to_string(self):
return f'''
session_id: {self.session_id}
prompt: {self.prompt}
seed: {self.seed}
num_inference_steps: {self.num_inference_steps}
@ -54,7 +58,8 @@ class Request:
use_upscale: {self.use_upscale}
show_only_filtered_image: {self.show_only_filtered_image}
stream_progress_updates: {self.stream_progress_updates}'''
stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''
class Image:
data: str # base64
@ -75,13 +80,11 @@ class Image:
class Response:
request: Request
session_id: str
images: list
def json(self):
res = {
"status": 'succeeded',
"session_id": self.session_id,
"request": self.request.json(),
"output": [],
}

View File

@ -35,8 +35,8 @@ import base64
from io import BytesIO
# local
session_id = str(uuid.uuid4())[-8:]
stop_processing = False
temp_images = {}
ckpt_file = None
gfpgan_file = None
@ -192,10 +192,11 @@ def mk_img(req: Request):
stop_processing = False
res = Response()
res.session_id = session_id
res.request = req
res.images = []
temp_images.clear()
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
@ -296,7 +297,7 @@ def mk_img(req: Request):
print(f"target t_enc is {t_enc} steps")
if opt_save_to_disk_path is not None:
session_out_path = os.path.join(opt_save_to_disk_path, session_id)
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
os.makedirs(session_out_path, exist_ok=True)
else:
session_out_path = None
@ -327,6 +328,8 @@ def mk_img(req: Request):
else:
c = modelCS.get_learned_conditioning(prompts)
modelFS.to(device)
partial_x_samples = None
def img_callback(x_samples, i):
nonlocal partial_x_samples
@ -334,7 +337,27 @@ def mk_img(req: Request):
partial_x_samples = x_samples
if req.stream_progress_updates:
yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
progress = {"step": i, "total_steps": opt_ddim_steps}
if req.stream_image_progress:
partial_images = []
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
@ -356,8 +379,6 @@ def mk_img(req: Request):
x_samples = partial_x_samples
modelFS.to(device)
print("saving images")
for i in range(batch_size):

View File

@ -31,6 +31,7 @@ outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
session_id: str = "session"
prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
@ -51,6 +52,7 @@ class ImageRequest(BaseModel):
show_only_filtered_image: bool = False
stream_progress_updates: bool = False
stream_image_progress: bool = False
class SetAppConfigRequest(BaseModel):
update_branch: str = "main"
@ -89,6 +91,7 @@ def image(req : ImageRequest):
from sd_internal import runtime
r = Request()
r.session_id = req.session_id
r.prompt = req.prompt
r.init_image = req.init_image
r.mask = req.mask
@ -109,6 +112,7 @@ def image(req : ImageRequest):
r.show_only_filtered_image = req.show_only_filtered_image
r.stream_progress_updates = req.stream_progress_updates
r.stream_image_progress = req.stream_image_progress
try:
res = runtime.mk_img(r)
@ -135,6 +139,13 @@ def stop():
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.get('/image/tmp/{session_id}/{img_id}')
def get_image(session_id, img_id):
from sd_internal import runtime
buf = runtime.temp_images[session_id + '/' + img_id]
buf.seek(0)
return StreamingResponse(buf, media_type='image/jpeg')
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
try: