mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-23 14:50:54 +01:00
Live preview of images
This commit is contained in:
parent
1d88a5b42e
commit
27071cfa29
@ -285,7 +285,7 @@
|
||||
<div id="server-status-color"> </div>
|
||||
<span id="server-status-msg">Stable Diffusion is starting..</span>
|
||||
</div>
|
||||
<h1>Stable Diffusion UI <small>v2.1 <span id="updateBranchLabel"></span></small></h1>
|
||||
<h1>Stable Diffusion UI <small>v2.11 <span id="updateBranchLabel"></span></small></h1>
|
||||
</div>
|
||||
<div id="editor-inputs">
|
||||
<div id="editor-inputs-prompt" class="row">
|
||||
@ -316,6 +316,7 @@
|
||||
<div id="editor-settings" class="panel-box">
|
||||
<h4 class="collapsible">Advanced Settings</h4>
|
||||
<ul id="editor-settings-entries" class="collapsible-content">
|
||||
<li><input id="stream_image_progress" name="stream_image_progress" type="checkbox" checked> <label for="stream_image_progress">Show a live preview of the image (disable this for faster image generation)</label></li>
|
||||
<li><input id="use_face_correction" name="use_face_correction" type="checkbox" checked> <label for="use_face_correction">Fix incorrect faces and eyes (uses GFPGAN)</label></li>
|
||||
<li>
|
||||
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale the image to 4x resolution using </label>
|
||||
@ -432,11 +433,14 @@ const MODIFIERS_PANEL_OPEN_KEY = "modifiersPanelOpen"
|
||||
const USE_FACE_CORRECTION_KEY = "useFaceCorrection"
|
||||
const USE_UPSCALING_KEY = "useUpscaling"
|
||||
const SHOW_ONLY_FILTERED_IMAGE_KEY = "showOnlyFilteredImage"
|
||||
const STREAM_IMAGE_PROGRESS_KEY = "streamImageProgress"
|
||||
const HEALTH_PING_INTERVAL = 5 // seconds
|
||||
const MAX_INIT_IMAGE_DIMENSION = 768
|
||||
|
||||
const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64')
|
||||
|
||||
let sessionId = new Date().getTime()
|
||||
|
||||
let promptField = document.querySelector('#prompt')
|
||||
let numOutputsTotalField = document.querySelector('#num_outputs_total')
|
||||
let numOutputsParallelField = document.querySelector('#num_outputs_parallel')
|
||||
@ -465,6 +469,7 @@ let useUpscalingField = document.querySelector("#use_upscale")
|
||||
let upscaleModelField = document.querySelector("#upscale_model")
|
||||
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
|
||||
let updateBranchLabel = document.querySelector("#updateBranchLabel")
|
||||
let streamImageProgressField = document.querySelector("#stream_image_progress")
|
||||
|
||||
let makeImageBtn = document.querySelector('#makeImage')
|
||||
let stopImageBtn = document.querySelector('#stopImage')
|
||||
@ -577,6 +582,10 @@ function isModifiersPanelOpenEnabled() {
|
||||
return getLocalStorageBoolItem(MODIFIERS_PANEL_OPEN_KEY, false)
|
||||
}
|
||||
|
||||
function isStreamImageProgressEnabled() {
|
||||
return getLocalStorageBoolItem(STREAM_IMAGE_PROGRESS_KEY, true)
|
||||
}
|
||||
|
||||
function setStatus(statusType, msg, msgType) {
|
||||
if (statusType !== 'server') {
|
||||
return;
|
||||
@ -636,6 +645,20 @@ async function healthCheck() {
|
||||
}
|
||||
}
|
||||
|
||||
function makeImageElement(width, height) {
|
||||
let imgItem = document.createElement('div')
|
||||
imgItem.className = 'imgItem'
|
||||
|
||||
let img = document.createElement('img')
|
||||
img.width = parseInt(width)
|
||||
img.height = parseInt(height)
|
||||
|
||||
imgItem.appendChild(img)
|
||||
imagesContainer.appendChild(imgItem)
|
||||
|
||||
return imgItem
|
||||
}
|
||||
|
||||
// makes a single image. don't call this directly, use makeImage() instead
|
||||
async function doMakeImage(reqBody, batchCount) {
|
||||
if (taskStopped) {
|
||||
@ -644,6 +667,17 @@ async function doMakeImage(reqBody, batchCount) {
|
||||
|
||||
let res = ''
|
||||
let seed = reqBody['seed']
|
||||
let numOutputs = parseInt(reqBody['num_outputs'])
|
||||
|
||||
let images = []
|
||||
|
||||
function makeImageContainers() {
|
||||
if (images.length === 0) {
|
||||
for (let i = 0; i < numOutputs; i++) {
|
||||
images.push(makeImageElement(reqBody.width, reqBody.height))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
res = await fetch('/image', {
|
||||
@ -693,6 +727,17 @@ async function doMakeImage(reqBody, batchCount) {
|
||||
progressBar.innerHTML += `<br>Time remaining (approx): ${millisecondsToStr(timeRemaining)}`
|
||||
}
|
||||
progressBar.style.display = 'block'
|
||||
|
||||
if (stepUpdate.output !== undefined) {
|
||||
makeImageContainers()
|
||||
|
||||
for (idx in stepUpdate.output) {
|
||||
let imgItem = images[idx]
|
||||
let img = imgItem.firstChild
|
||||
let tmpImageData = stepUpdate.output[idx]
|
||||
img.src = tmpImageData['path'] + '?t=' + new Date().getTime()
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
finalJSON += jsonStr
|
||||
@ -751,6 +796,8 @@ async function doMakeImage(reqBody, batchCount) {
|
||||
|
||||
lastPromptUsed = reqBody['prompt']
|
||||
|
||||
makeImageContainers()
|
||||
|
||||
for (let idx in res.output) {
|
||||
let imgBody = ''
|
||||
let seed = 0
|
||||
@ -765,12 +812,9 @@ async function doMakeImage(reqBody, batchCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
let imgItem = document.createElement('div')
|
||||
imgItem.className = 'imgItem'
|
||||
let imgItem = images[idx]
|
||||
let img = imgItem.firstChild
|
||||
|
||||
let img = document.createElement('img')
|
||||
img.width = parseInt(reqBody.width)
|
||||
img.height = parseInt(reqBody.height)
|
||||
img.src = imgBody
|
||||
|
||||
let imgItemInfo = document.createElement('span')
|
||||
@ -788,12 +832,10 @@ async function doMakeImage(reqBody, batchCount) {
|
||||
imgSaveBtn.className = 'imgSaveBtn'
|
||||
imgSaveBtn.innerHTML = 'Download'
|
||||
|
||||
imgItem.appendChild(img)
|
||||
imgItem.appendChild(imgItemInfo)
|
||||
imgItemInfo.appendChild(imgSeedLabel)
|
||||
imgItemInfo.appendChild(imgUseBtn)
|
||||
imgItemInfo.appendChild(imgSaveBtn)
|
||||
imagesContainer.appendChild(imgItem)
|
||||
|
||||
imgUseBtn.addEventListener('click', function() {
|
||||
initImageSelector.value = null
|
||||
@ -876,6 +918,8 @@ async function makeImage() {
|
||||
let batchCount = Math.ceil(numOutputsTotal / numOutputsParallel)
|
||||
let batchSize = numOutputsParallel
|
||||
|
||||
let streamImageProgress = (numOutputsTotal > 50 ? false : streamImageProgressField.checked)
|
||||
|
||||
let prompt = promptField.value
|
||||
if (activeTags.length > 0) {
|
||||
let promptTags = activeTags.join(", ")
|
||||
@ -885,6 +929,7 @@ async function makeImage() {
|
||||
previewPrompt.innerHTML = prompt
|
||||
|
||||
let reqBody = {
|
||||
session_id: sessionId,
|
||||
prompt: prompt,
|
||||
num_outputs: batchSize,
|
||||
num_inference_steps: numInferenceStepsField.value,
|
||||
@ -895,7 +940,8 @@ async function makeImage() {
|
||||
turbo: turboField.checked,
|
||||
use_cpu: useCPUField.checked,
|
||||
use_full_precision: useFullPrecisionField.checked,
|
||||
stream_progress_updates: true
|
||||
stream_progress_updates: true,
|
||||
stream_image_progress: streamImageProgress
|
||||
}
|
||||
|
||||
if (IMAGE_REGEX.test(initImagePreview.src)) {
|
||||
@ -1036,6 +1082,9 @@ useFullPrecisionField.checked = isUseFullPrecisionEnabled()
|
||||
turboField.addEventListener('click', handleBoolSettingChange(USE_TURBO_MODE_KEY))
|
||||
turboField.checked = isUseTurboModeEnabled()
|
||||
|
||||
streamImageProgressField.addEventListener('click', handleBoolSettingChange(STREAM_IMAGE_PROGRESS_KEY))
|
||||
streamImageProgressField.checked = isStreamImageProgressEnabled()
|
||||
|
||||
diskPathField.addEventListener('change', handleStringSettingChange(DISK_PATH_KEY))
|
||||
|
||||
saveToDiskField.addEventListener('click', function(e) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
class Request:
|
||||
session_id: str = "session"
|
||||
prompt: str = ""
|
||||
init_image: str = None # base64
|
||||
mask: str = None # base64
|
||||
@ -22,9 +23,11 @@ class Request:
|
||||
show_only_filtered_image: bool = False
|
||||
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
def json(self):
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"prompt": self.prompt,
|
||||
"num_outputs": self.num_outputs,
|
||||
"num_inference_steps": self.num_inference_steps,
|
||||
@ -39,6 +42,7 @@ class Request:
|
||||
|
||||
def to_string(self):
|
||||
return f'''
|
||||
session_id: {self.session_id}
|
||||
prompt: {self.prompt}
|
||||
seed: {self.seed}
|
||||
num_inference_steps: {self.num_inference_steps}
|
||||
@ -54,7 +58,8 @@ class Request:
|
||||
use_upscale: {self.use_upscale}
|
||||
show_only_filtered_image: {self.show_only_filtered_image}
|
||||
|
||||
stream_progress_updates: {self.stream_progress_updates}'''
|
||||
stream_progress_updates: {self.stream_progress_updates}
|
||||
stream_image_progress: {self.stream_image_progress}'''
|
||||
|
||||
class Image:
|
||||
data: str # base64
|
||||
@ -75,13 +80,11 @@ class Image:
|
||||
|
||||
class Response:
|
||||
request: Request
|
||||
session_id: str
|
||||
images: list
|
||||
|
||||
def json(self):
|
||||
res = {
|
||||
"status": 'succeeded',
|
||||
"session_id": self.session_id,
|
||||
"request": self.request.json(),
|
||||
"output": [],
|
||||
}
|
||||
|
@ -35,8 +35,8 @@ import base64
|
||||
from io import BytesIO
|
||||
|
||||
# local
|
||||
session_id = str(uuid.uuid4())[-8:]
|
||||
stop_processing = False
|
||||
temp_images = {}
|
||||
|
||||
ckpt_file = None
|
||||
gfpgan_file = None
|
||||
@ -192,10 +192,11 @@ def mk_img(req: Request):
|
||||
stop_processing = False
|
||||
|
||||
res = Response()
|
||||
res.session_id = session_id
|
||||
res.request = req
|
||||
res.images = []
|
||||
|
||||
temp_images.clear()
|
||||
|
||||
model.turbo = req.turbo
|
||||
if req.use_cpu:
|
||||
if device != 'cpu':
|
||||
@ -296,7 +297,7 @@ def mk_img(req: Request):
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
if opt_save_to_disk_path is not None:
|
||||
session_out_path = os.path.join(opt_save_to_disk_path, session_id)
|
||||
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
|
||||
os.makedirs(session_out_path, exist_ok=True)
|
||||
else:
|
||||
session_out_path = None
|
||||
@ -327,6 +328,8 @@ def mk_img(req: Request):
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
modelFS.to(device)
|
||||
|
||||
partial_x_samples = None
|
||||
def img_callback(x_samples, i):
|
||||
nonlocal partial_x_samples
|
||||
@ -334,7 +337,27 @@ def mk_img(req: Request):
|
||||
partial_x_samples = x_samples
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
|
||||
progress = {"step": i, "total_steps": opt_ddim_steps}
|
||||
|
||||
if req.stream_image_progress:
|
||||
partial_images = []
|
||||
|
||||
for i in range(batch_size):
|
||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
img = Image.fromarray(x_sample)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format='JPEG')
|
||||
buf.seek(0)
|
||||
|
||||
temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||
|
||||
progress['output'] = partial_images
|
||||
|
||||
yield json.dumps(progress)
|
||||
|
||||
if stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
@ -356,8 +379,6 @@ def mk_img(req: Request):
|
||||
|
||||
x_samples = partial_x_samples
|
||||
|
||||
modelFS.to(device)
|
||||
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
|
11
ui/server.py
11
ui/server.py
@ -31,6 +31,7 @@ outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
||||
|
||||
# defaults from https://huggingface.co/blog/stable_diffusion
|
||||
class ImageRequest(BaseModel):
|
||||
session_id: str = "session"
|
||||
prompt: str = ""
|
||||
init_image: str = None # base64
|
||||
mask: str = None # base64
|
||||
@ -51,6 +52,7 @@ class ImageRequest(BaseModel):
|
||||
show_only_filtered_image: bool = False
|
||||
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = "main"
|
||||
@ -89,6 +91,7 @@ def image(req : ImageRequest):
|
||||
from sd_internal import runtime
|
||||
|
||||
r = Request()
|
||||
r.session_id = req.session_id
|
||||
r.prompt = req.prompt
|
||||
r.init_image = req.init_image
|
||||
r.mask = req.mask
|
||||
@ -109,6 +112,7 @@ def image(req : ImageRequest):
|
||||
r.show_only_filtered_image = req.show_only_filtered_image
|
||||
|
||||
r.stream_progress_updates = req.stream_progress_updates
|
||||
r.stream_image_progress = req.stream_image_progress
|
||||
|
||||
try:
|
||||
res = runtime.mk_img(r)
|
||||
@ -135,6 +139,13 @@ def stop():
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/image/tmp/{session_id}/{img_id}')
|
||||
def get_image(session_id, img_id):
|
||||
from sd_internal import runtime
|
||||
buf = runtime.temp_images[session_id + '/' + img_id]
|
||||
buf.seek(0)
|
||||
return StreamingResponse(buf, media_type='image/jpeg')
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user