From 27071cfa29f89c7cff1563e0dedfec556b9bc60d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Sep 2022 22:29:42 +0530 Subject: [PATCH] Live preview of images --- ui/index.html | 67 +++++++++++++++++++++++++++++++++----- ui/sd_internal/__init__.py | 9 +++-- ui/sd_internal/runtime.py | 33 +++++++++++++++---- ui/server.py | 11 +++++++ 4 files changed, 102 insertions(+), 18 deletions(-) diff --git a/ui/index.html b/ui/index.html index 08db3599..f165090b 100644 --- a/ui/index.html +++ b/ui/index.html @@ -285,7 +285,7 @@
 
Stable Diffusion is starting.. -

Stable Diffusion UI v2.1

+

Stable Diffusion UI v2.11

@@ -316,6 +316,7 @@

Advanced Settings

    +
  • @@ -432,11 +433,14 @@ const MODIFIERS_PANEL_OPEN_KEY = "modifiersPanelOpen" const USE_FACE_CORRECTION_KEY = "useFaceCorrection" const USE_UPSCALING_KEY = "useUpscaling" const SHOW_ONLY_FILTERED_IMAGE_KEY = "showOnlyFilteredImage" +const STREAM_IMAGE_PROGRESS_KEY = "streamImageProgress" const HEALTH_PING_INTERVAL = 5 // seconds const MAX_INIT_IMAGE_DIMENSION = 768 const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64') +let sessionId = new Date().getTime() + let promptField = document.querySelector('#prompt') let numOutputsTotalField = document.querySelector('#num_outputs_total') let numOutputsParallelField = document.querySelector('#num_outputs_parallel') @@ -465,6 +469,7 @@ let useUpscalingField = document.querySelector("#use_upscale") let upscaleModelField = document.querySelector("#upscale_model") let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image") let updateBranchLabel = document.querySelector("#updateBranchLabel") +let streamImageProgressField = document.querySelector("#stream_image_progress") let makeImageBtn = document.querySelector('#makeImage') let stopImageBtn = document.querySelector('#stopImage') @@ -577,6 +582,10 @@ function isModifiersPanelOpenEnabled() { return getLocalStorageBoolItem(MODIFIERS_PANEL_OPEN_KEY, false) } +function isStreamImageProgressEnabled() { + return getLocalStorageBoolItem(STREAM_IMAGE_PROGRESS_KEY, true) +} + function setStatus(statusType, msg, msgType) { if (statusType !== 'server') { return; @@ -636,6 +645,20 @@ async function healthCheck() { } } +function makeImageElement(width, height) { + let imgItem = document.createElement('div') + imgItem.className = 'imgItem' + + let img = document.createElement('img') + img.width = parseInt(width) + img.height = parseInt(height) + + imgItem.appendChild(img) + imagesContainer.appendChild(imgItem) + + return imgItem +} + // makes a single image. don't call this directly, use makeImage() instead async function doMakeImage(reqBody, batchCount) { if (taskStopped) { @@ -644,6 +667,17 @@ async function doMakeImage(reqBody, batchCount) { let res = '' let seed = reqBody['seed'] + let numOutputs = parseInt(reqBody['num_outputs']) + + let images = [] + + function makeImageContainers() { + if (images.length === 0) { + for (let i = 0; i < numOutputs; i++) { + images.push(makeImageElement(reqBody.width, reqBody.height)) + } + } + } try { res = await fetch('/image', { @@ -693,6 +727,17 @@ async function doMakeImage(reqBody, batchCount) { progressBar.innerHTML += `
    Time remaining (approx): ${millisecondsToStr(timeRemaining)}` } progressBar.style.display = 'block' + + if (stepUpdate.output !== undefined) { + makeImageContainers() + + for (idx in stepUpdate.output) { + let imgItem = images[idx] + let img = imgItem.firstChild + let tmpImageData = stepUpdate.output[idx] + img.src = tmpImageData['path'] + '?t=' + new Date().getTime() + } + } } } catch (e) { finalJSON += jsonStr @@ -751,6 +796,8 @@ async function doMakeImage(reqBody, batchCount) { lastPromptUsed = reqBody['prompt'] + makeImageContainers() + for (let idx in res.output) { let imgBody = '' let seed = 0 @@ -765,12 +812,9 @@ async function doMakeImage(reqBody, batchCount) { continue } - let imgItem = document.createElement('div') - imgItem.className = 'imgItem' + let imgItem = images[idx] + let img = imgItem.firstChild - let img = document.createElement('img') - img.width = parseInt(reqBody.width) - img.height = parseInt(reqBody.height) img.src = imgBody let imgItemInfo = document.createElement('span') @@ -788,12 +832,10 @@ async function doMakeImage(reqBody, batchCount) { imgSaveBtn.className = 'imgSaveBtn' imgSaveBtn.innerHTML = 'Download' - imgItem.appendChild(img) imgItem.appendChild(imgItemInfo) imgItemInfo.appendChild(imgSeedLabel) imgItemInfo.appendChild(imgUseBtn) imgItemInfo.appendChild(imgSaveBtn) - imagesContainer.appendChild(imgItem) imgUseBtn.addEventListener('click', function() { initImageSelector.value = null @@ -876,6 +918,8 @@ async function makeImage() { let batchCount = Math.ceil(numOutputsTotal / numOutputsParallel) let batchSize = numOutputsParallel + let streamImageProgress = (numOutputsTotal > 50 ? false : streamImageProgressField.checked) + let prompt = promptField.value if (activeTags.length > 0) { let promptTags = activeTags.join(", ") @@ -885,6 +929,7 @@ async function makeImage() { previewPrompt.innerHTML = prompt let reqBody = { + session_id: sessionId, prompt: prompt, num_outputs: batchSize, num_inference_steps: numInferenceStepsField.value, @@ -895,7 +940,8 @@ async function makeImage() { turbo: turboField.checked, use_cpu: useCPUField.checked, use_full_precision: useFullPrecisionField.checked, - stream_progress_updates: true + stream_progress_updates: true, + stream_image_progress: streamImageProgress } if (IMAGE_REGEX.test(initImagePreview.src)) { @@ -1036,6 +1082,9 @@ useFullPrecisionField.checked = isUseFullPrecisionEnabled() turboField.addEventListener('click', handleBoolSettingChange(USE_TURBO_MODE_KEY)) turboField.checked = isUseTurboModeEnabled() +streamImageProgressField.addEventListener('click', handleBoolSettingChange(STREAM_IMAGE_PROGRESS_KEY)) +streamImageProgressField.checked = isStreamImageProgressEnabled() + diskPathField.addEventListener('change', handleStringSettingChange(DISK_PATH_KEY)) saveToDiskField.addEventListener('click', function(e) { diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index d6173823..c0b6c6dc 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -1,6 +1,7 @@ import json class Request: + session_id: str = "session" prompt: str = "" init_image: str = None # base64 mask: str = None # base64 @@ -22,9 +23,11 @@ class Request: show_only_filtered_image: bool = False stream_progress_updates: bool = False + stream_image_progress: bool = False def json(self): return { + "session_id": self.session_id, "prompt": self.prompt, "num_outputs": self.num_outputs, "num_inference_steps": self.num_inference_steps, @@ -39,6 +42,7 @@ class Request: def to_string(self): return f''' + session_id: {self.session_id} prompt: {self.prompt} seed: {self.seed} num_inference_steps: {self.num_inference_steps} @@ -54,7 +58,8 @@ class Request: use_upscale: {self.use_upscale} show_only_filtered_image: {self.show_only_filtered_image} - stream_progress_updates: {self.stream_progress_updates}''' + stream_progress_updates: {self.stream_progress_updates} + stream_image_progress: {self.stream_image_progress}''' class Image: data: str # base64 @@ -75,13 +80,11 @@ class Image: class Response: request: Request - session_id: str images: list def json(self): res = { "status": 'succeeded', - "session_id": self.session_id, "request": self.request.json(), "output": [], } diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index ca52cc64..613cb967 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -35,8 +35,8 @@ import base64 from io import BytesIO # local -session_id = str(uuid.uuid4())[-8:] stop_processing = False +temp_images = {} ckpt_file = None gfpgan_file = None @@ -192,10 +192,11 @@ def mk_img(req: Request): stop_processing = False res = Response() - res.session_id = session_id res.request = req res.images = [] + temp_images.clear() + model.turbo = req.turbo if req.use_cpu: if device != 'cpu': @@ -296,7 +297,7 @@ def mk_img(req: Request): print(f"target t_enc is {t_enc} steps") if opt_save_to_disk_path is not None: - session_out_path = os.path.join(opt_save_to_disk_path, session_id) + session_out_path = os.path.join(opt_save_to_disk_path, req.session_id) os.makedirs(session_out_path, exist_ok=True) else: session_out_path = None @@ -327,6 +328,8 @@ def mk_img(req: Request): else: c = modelCS.get_learned_conditioning(prompts) + modelFS.to(device) + partial_x_samples = None def img_callback(x_samples, i): nonlocal partial_x_samples @@ -334,7 +337,27 @@ def mk_img(req: Request): partial_x_samples = x_samples if req.stream_progress_updates: - yield json.dumps({"step": i, "total_steps": opt_ddim_steps}) + progress = {"step": i, "total_steps": opt_ddim_steps} + + if req.stream_image_progress: + partial_images = [] + + for i in range(batch_size): + x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) + x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") + x_sample = x_sample.astype(np.uint8) + img = Image.fromarray(x_sample) + buf = BytesIO() + img.save(buf, format='JPEG') + buf.seek(0) + + temp_images[str(req.session_id) + '/' + str(i)] = buf + partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) + + progress['output'] = partial_images + + yield json.dumps(progress) if stop_processing: raise UserInitiatedStop("User requested that we stop processing") @@ -356,8 +379,6 @@ def mk_img(req: Request): x_samples = partial_x_samples - modelFS.to(device) - print("saving images") for i in range(batch_size): diff --git a/ui/server.py b/ui/server.py index c7c2766c..3a79317b 100644 --- a/ui/server.py +++ b/ui/server.py @@ -31,6 +31,7 @@ outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) # defaults from https://huggingface.co/blog/stable_diffusion class ImageRequest(BaseModel): + session_id: str = "session" prompt: str = "" init_image: str = None # base64 mask: str = None # base64 @@ -51,6 +52,7 @@ class ImageRequest(BaseModel): show_only_filtered_image: bool = False stream_progress_updates: bool = False + stream_image_progress: bool = False class SetAppConfigRequest(BaseModel): update_branch: str = "main" @@ -89,6 +91,7 @@ def image(req : ImageRequest): from sd_internal import runtime r = Request() + r.session_id = req.session_id r.prompt = req.prompt r.init_image = req.init_image r.mask = req.mask @@ -109,6 +112,7 @@ def image(req : ImageRequest): r.show_only_filtered_image = req.show_only_filtered_image r.stream_progress_updates = req.stream_progress_updates + r.stream_image_progress = req.stream_image_progress try: res = runtime.mk_img(r) @@ -135,6 +139,13 @@ def stop(): print(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e)) +@app.get('/image/tmp/{session_id}/{img_id}') +def get_image(session_id, img_id): + from sd_internal import runtime + buf = runtime.temp_images[session_id + '/' + img_id] + buf.seek(0) + return StreamingResponse(buf, media_type='image/jpeg') + @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): try: