Live preview of images

This commit is contained in:
cmdr2 2022-09-14 22:29:42 +05:30
parent 1d88a5b42e
commit 27071cfa29
4 changed files with 102 additions and 18 deletions

View File

@ -285,7 +285,7 @@
<div id="server-status-color">&nbsp;</div> <div id="server-status-color">&nbsp;</div>
<span id="server-status-msg">Stable Diffusion is starting..</span> <span id="server-status-msg">Stable Diffusion is starting..</span>
</div> </div>
<h1>Stable Diffusion UI <small>v2.1 <span id="updateBranchLabel"></span></small></h1> <h1>Stable Diffusion UI <small>v2.11 <span id="updateBranchLabel"></span></small></h1>
</div> </div>
<div id="editor-inputs"> <div id="editor-inputs">
<div id="editor-inputs-prompt" class="row"> <div id="editor-inputs-prompt" class="row">
@ -316,6 +316,7 @@
<div id="editor-settings" class="panel-box"> <div id="editor-settings" class="panel-box">
<h4 class="collapsible">Advanced Settings</h4> <h4 class="collapsible">Advanced Settings</h4>
<ul id="editor-settings-entries" class="collapsible-content"> <ul id="editor-settings-entries" class="collapsible-content">
<li><input id="stream_image_progress" name="stream_image_progress" type="checkbox" checked> <label for="stream_image_progress">Show a live preview of the image (disable this for faster image generation)</label></li>
<li><input id="use_face_correction" name="use_face_correction" type="checkbox" checked> <label for="use_face_correction">Fix incorrect faces and eyes (uses GFPGAN)</label></li> <li><input id="use_face_correction" name="use_face_correction" type="checkbox" checked> <label for="use_face_correction">Fix incorrect faces and eyes (uses GFPGAN)</label></li>
<li> <li>
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale the image to 4x resolution using </label> <input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale the image to 4x resolution using </label>
@ -432,11 +433,14 @@ const MODIFIERS_PANEL_OPEN_KEY = "modifiersPanelOpen"
const USE_FACE_CORRECTION_KEY = "useFaceCorrection" const USE_FACE_CORRECTION_KEY = "useFaceCorrection"
const USE_UPSCALING_KEY = "useUpscaling" const USE_UPSCALING_KEY = "useUpscaling"
const SHOW_ONLY_FILTERED_IMAGE_KEY = "showOnlyFilteredImage" const SHOW_ONLY_FILTERED_IMAGE_KEY = "showOnlyFilteredImage"
const STREAM_IMAGE_PROGRESS_KEY = "streamImageProgress"
const HEALTH_PING_INTERVAL = 5 // seconds const HEALTH_PING_INTERVAL = 5 // seconds
const MAX_INIT_IMAGE_DIMENSION = 768 const MAX_INIT_IMAGE_DIMENSION = 768
const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64') const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64')
let sessionId = new Date().getTime()
let promptField = document.querySelector('#prompt') let promptField = document.querySelector('#prompt')
let numOutputsTotalField = document.querySelector('#num_outputs_total') let numOutputsTotalField = document.querySelector('#num_outputs_total')
let numOutputsParallelField = document.querySelector('#num_outputs_parallel') let numOutputsParallelField = document.querySelector('#num_outputs_parallel')
@ -465,6 +469,7 @@ let useUpscalingField = document.querySelector("#use_upscale")
let upscaleModelField = document.querySelector("#upscale_model") let upscaleModelField = document.querySelector("#upscale_model")
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image") let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
let updateBranchLabel = document.querySelector("#updateBranchLabel") let updateBranchLabel = document.querySelector("#updateBranchLabel")
let streamImageProgressField = document.querySelector("#stream_image_progress")
let makeImageBtn = document.querySelector('#makeImage') let makeImageBtn = document.querySelector('#makeImage')
let stopImageBtn = document.querySelector('#stopImage') let stopImageBtn = document.querySelector('#stopImage')
@ -577,6 +582,10 @@ function isModifiersPanelOpenEnabled() {
return getLocalStorageBoolItem(MODIFIERS_PANEL_OPEN_KEY, false) return getLocalStorageBoolItem(MODIFIERS_PANEL_OPEN_KEY, false)
} }
function isStreamImageProgressEnabled() {
return getLocalStorageBoolItem(STREAM_IMAGE_PROGRESS_KEY, true)
}
function setStatus(statusType, msg, msgType) { function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') { if (statusType !== 'server') {
return; return;
@ -636,6 +645,20 @@ async function healthCheck() {
} }
} }
function makeImageElement(width, height) {
let imgItem = document.createElement('div')
imgItem.className = 'imgItem'
let img = document.createElement('img')
img.width = parseInt(width)
img.height = parseInt(height)
imgItem.appendChild(img)
imagesContainer.appendChild(imgItem)
return imgItem
}
// makes a single image. don't call this directly, use makeImage() instead // makes a single image. don't call this directly, use makeImage() instead
async function doMakeImage(reqBody, batchCount) { async function doMakeImage(reqBody, batchCount) {
if (taskStopped) { if (taskStopped) {
@ -644,6 +667,17 @@ async function doMakeImage(reqBody, batchCount) {
let res = '' let res = ''
let seed = reqBody['seed'] let seed = reqBody['seed']
let numOutputs = parseInt(reqBody['num_outputs'])
let images = []
function makeImageContainers() {
if (images.length === 0) {
for (let i = 0; i < numOutputs; i++) {
images.push(makeImageElement(reqBody.width, reqBody.height))
}
}
}
try { try {
res = await fetch('/image', { res = await fetch('/image', {
@ -693,6 +727,17 @@ async function doMakeImage(reqBody, batchCount) {
progressBar.innerHTML += `<br>Time remaining (approx): ${millisecondsToStr(timeRemaining)}` progressBar.innerHTML += `<br>Time remaining (approx): ${millisecondsToStr(timeRemaining)}`
} }
progressBar.style.display = 'block' progressBar.style.display = 'block'
if (stepUpdate.output !== undefined) {
makeImageContainers()
for (idx in stepUpdate.output) {
let imgItem = images[idx]
let img = imgItem.firstChild
let tmpImageData = stepUpdate.output[idx]
img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
}
}
} }
} catch (e) { } catch (e) {
finalJSON += jsonStr finalJSON += jsonStr
@ -751,6 +796,8 @@ async function doMakeImage(reqBody, batchCount) {
lastPromptUsed = reqBody['prompt'] lastPromptUsed = reqBody['prompt']
makeImageContainers()
for (let idx in res.output) { for (let idx in res.output) {
let imgBody = '' let imgBody = ''
let seed = 0 let seed = 0
@ -765,12 +812,9 @@ async function doMakeImage(reqBody, batchCount) {
continue continue
} }
let imgItem = document.createElement('div') let imgItem = images[idx]
imgItem.className = 'imgItem' let img = imgItem.firstChild
let img = document.createElement('img')
img.width = parseInt(reqBody.width)
img.height = parseInt(reqBody.height)
img.src = imgBody img.src = imgBody
let imgItemInfo = document.createElement('span') let imgItemInfo = document.createElement('span')
@ -788,12 +832,10 @@ async function doMakeImage(reqBody, batchCount) {
imgSaveBtn.className = 'imgSaveBtn' imgSaveBtn.className = 'imgSaveBtn'
imgSaveBtn.innerHTML = 'Download' imgSaveBtn.innerHTML = 'Download'
imgItem.appendChild(img)
imgItem.appendChild(imgItemInfo) imgItem.appendChild(imgItemInfo)
imgItemInfo.appendChild(imgSeedLabel) imgItemInfo.appendChild(imgSeedLabel)
imgItemInfo.appendChild(imgUseBtn) imgItemInfo.appendChild(imgUseBtn)
imgItemInfo.appendChild(imgSaveBtn) imgItemInfo.appendChild(imgSaveBtn)
imagesContainer.appendChild(imgItem)
imgUseBtn.addEventListener('click', function() { imgUseBtn.addEventListener('click', function() {
initImageSelector.value = null initImageSelector.value = null
@ -876,6 +918,8 @@ async function makeImage() {
let batchCount = Math.ceil(numOutputsTotal / numOutputsParallel) let batchCount = Math.ceil(numOutputsTotal / numOutputsParallel)
let batchSize = numOutputsParallel let batchSize = numOutputsParallel
let streamImageProgress = (numOutputsTotal > 50 ? false : streamImageProgressField.checked)
let prompt = promptField.value let prompt = promptField.value
if (activeTags.length > 0) { if (activeTags.length > 0) {
let promptTags = activeTags.join(", ") let promptTags = activeTags.join(", ")
@ -885,6 +929,7 @@ async function makeImage() {
previewPrompt.innerHTML = prompt previewPrompt.innerHTML = prompt
let reqBody = { let reqBody = {
session_id: sessionId,
prompt: prompt, prompt: prompt,
num_outputs: batchSize, num_outputs: batchSize,
num_inference_steps: numInferenceStepsField.value, num_inference_steps: numInferenceStepsField.value,
@ -895,7 +940,8 @@ async function makeImage() {
turbo: turboField.checked, turbo: turboField.checked,
use_cpu: useCPUField.checked, use_cpu: useCPUField.checked,
use_full_precision: useFullPrecisionField.checked, use_full_precision: useFullPrecisionField.checked,
stream_progress_updates: true stream_progress_updates: true,
stream_image_progress: streamImageProgress
} }
if (IMAGE_REGEX.test(initImagePreview.src)) { if (IMAGE_REGEX.test(initImagePreview.src)) {
@ -1036,6 +1082,9 @@ useFullPrecisionField.checked = isUseFullPrecisionEnabled()
turboField.addEventListener('click', handleBoolSettingChange(USE_TURBO_MODE_KEY)) turboField.addEventListener('click', handleBoolSettingChange(USE_TURBO_MODE_KEY))
turboField.checked = isUseTurboModeEnabled() turboField.checked = isUseTurboModeEnabled()
streamImageProgressField.addEventListener('click', handleBoolSettingChange(STREAM_IMAGE_PROGRESS_KEY))
streamImageProgressField.checked = isStreamImageProgressEnabled()
diskPathField.addEventListener('change', handleStringSettingChange(DISK_PATH_KEY)) diskPathField.addEventListener('change', handleStringSettingChange(DISK_PATH_KEY))
saveToDiskField.addEventListener('click', function(e) { saveToDiskField.addEventListener('click', function(e) {

View File

@ -1,6 +1,7 @@
import json import json
class Request: class Request:
session_id: str = "session"
prompt: str = "" prompt: str = ""
init_image: str = None # base64 init_image: str = None # base64
mask: str = None # base64 mask: str = None # base64
@ -22,9 +23,11 @@ class Request:
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
stream_progress_updates: bool = False stream_progress_updates: bool = False
stream_image_progress: bool = False
def json(self): def json(self):
return { return {
"session_id": self.session_id,
"prompt": self.prompt, "prompt": self.prompt,
"num_outputs": self.num_outputs, "num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps, "num_inference_steps": self.num_inference_steps,
@ -39,6 +42,7 @@ class Request:
def to_string(self): def to_string(self):
return f''' return f'''
session_id: {self.session_id}
prompt: {self.prompt} prompt: {self.prompt}
seed: {self.seed} seed: {self.seed}
num_inference_steps: {self.num_inference_steps} num_inference_steps: {self.num_inference_steps}
@ -54,7 +58,8 @@ class Request:
use_upscale: {self.use_upscale} use_upscale: {self.use_upscale}
show_only_filtered_image: {self.show_only_filtered_image} show_only_filtered_image: {self.show_only_filtered_image}
stream_progress_updates: {self.stream_progress_updates}''' stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''
class Image: class Image:
data: str # base64 data: str # base64
@ -75,13 +80,11 @@ class Image:
class Response: class Response:
request: Request request: Request
session_id: str
images: list images: list
def json(self): def json(self):
res = { res = {
"status": 'succeeded', "status": 'succeeded',
"session_id": self.session_id,
"request": self.request.json(), "request": self.request.json(),
"output": [], "output": [],
} }

View File

@ -35,8 +35,8 @@ import base64
from io import BytesIO from io import BytesIO
# local # local
session_id = str(uuid.uuid4())[-8:]
stop_processing = False stop_processing = False
temp_images = {}
ckpt_file = None ckpt_file = None
gfpgan_file = None gfpgan_file = None
@ -192,10 +192,11 @@ def mk_img(req: Request):
stop_processing = False stop_processing = False
res = Response() res = Response()
res.session_id = session_id
res.request = req res.request = req
res.images = [] res.images = []
temp_images.clear()
model.turbo = req.turbo model.turbo = req.turbo
if req.use_cpu: if req.use_cpu:
if device != 'cpu': if device != 'cpu':
@ -296,7 +297,7 @@ def mk_img(req: Request):
print(f"target t_enc is {t_enc} steps") print(f"target t_enc is {t_enc} steps")
if opt_save_to_disk_path is not None: if opt_save_to_disk_path is not None:
session_out_path = os.path.join(opt_save_to_disk_path, session_id) session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
os.makedirs(session_out_path, exist_ok=True) os.makedirs(session_out_path, exist_ok=True)
else: else:
session_out_path = None session_out_path = None
@ -327,6 +328,8 @@ def mk_img(req: Request):
else: else:
c = modelCS.get_learned_conditioning(prompts) c = modelCS.get_learned_conditioning(prompts)
modelFS.to(device)
partial_x_samples = None partial_x_samples = None
def img_callback(x_samples, i): def img_callback(x_samples, i):
nonlocal partial_x_samples nonlocal partial_x_samples
@ -334,7 +337,27 @@ def mk_img(req: Request):
partial_x_samples = x_samples partial_x_samples = x_samples
if req.stream_progress_updates: if req.stream_progress_updates:
yield json.dumps({"step": i, "total_steps": opt_ddim_steps}) progress = {"step": i, "total_steps": opt_ddim_steps}
if req.stream_image_progress:
partial_images = []
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if stop_processing: if stop_processing:
raise UserInitiatedStop("User requested that we stop processing") raise UserInitiatedStop("User requested that we stop processing")
@ -356,8 +379,6 @@ def mk_img(req: Request):
x_samples = partial_x_samples x_samples = partial_x_samples
modelFS.to(device)
print("saving images") print("saving images")
for i in range(batch_size): for i in range(batch_size):

View File

@ -31,6 +31,7 @@ outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
# defaults from https://huggingface.co/blog/stable_diffusion # defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel): class ImageRequest(BaseModel):
session_id: str = "session"
prompt: str = "" prompt: str = ""
init_image: str = None # base64 init_image: str = None # base64
mask: str = None # base64 mask: str = None # base64
@ -51,6 +52,7 @@ class ImageRequest(BaseModel):
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
stream_progress_updates: bool = False stream_progress_updates: bool = False
stream_image_progress: bool = False
class SetAppConfigRequest(BaseModel): class SetAppConfigRequest(BaseModel):
update_branch: str = "main" update_branch: str = "main"
@ -89,6 +91,7 @@ def image(req : ImageRequest):
from sd_internal import runtime from sd_internal import runtime
r = Request() r = Request()
r.session_id = req.session_id
r.prompt = req.prompt r.prompt = req.prompt
r.init_image = req.init_image r.init_image = req.init_image
r.mask = req.mask r.mask = req.mask
@ -109,6 +112,7 @@ def image(req : ImageRequest):
r.show_only_filtered_image = req.show_only_filtered_image r.show_only_filtered_image = req.show_only_filtered_image
r.stream_progress_updates = req.stream_progress_updates r.stream_progress_updates = req.stream_progress_updates
r.stream_image_progress = req.stream_image_progress
try: try:
res = runtime.mk_img(r) res = runtime.mk_img(r)
@ -135,6 +139,13 @@ def stop():
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@app.get('/image/tmp/{session_id}/{img_id}')
def get_image(session_id, img_id):
from sd_internal import runtime
buf = runtime.temp_images[session_id + '/' + img_id]
buf.seek(0)
return StreamingResponse(buf, media_type='image/jpeg')
@app.post('/app_config') @app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest): async def setAppConfig(req : SetAppConfigRequest):
try: try: