forked from extern/easydiffusion
Show the progress percentage while generating images
parent 3b47eb3b07, commit 9f48d5e5ff
@@ -60,6 +60,9 @@
 #outputMsg {
     font-size: small;
 }
+#progressBar {
+    font-size: small;
+}
 #footer {
     font-size: small;
     padding-left: 10pt;
@@ -397,6 +400,7 @@
     <div id="preview-prompt">Type a prompt and press the "Make Image" button.<br/><br/>You can set an "Initial Image" if you want to guide the AI.<br/><br/>You can also add modifiers like "Realistic", "Pencil Sketch", "ArtStation" etc by browsing through the "Image Modifiers" section and selecting the desired modifiers.<br/><br/>Click "Advanced Settings" for additional settings like seed, image size, number of images to generate etc.<br/><br/>Enjoy! :)</div>

     <div id="outputMsg"></div>
+    <div id="progressBar"></div>
     <div id="current-images" class="img-preview">
     </div>
 </div>
@@ -483,6 +487,7 @@ let previewPrompt = document.querySelector('#preview-prompt')
 let showConfigToggle = document.querySelector('#configToggleBtn')
 // let configBox = document.querySelector('#config')
 let outputMsg = document.querySelector('#outputMsg')
+let progressBar = document.querySelector("#progressBar")

 let soundToggle = document.querySelector('#sound_toggle')

@@ -496,6 +501,7 @@ let serverStatus = 'offline'
 let activeTags = []
 let lastPromptUsed = ''
 let taskStopped = true
+let batchesDone = 0

 function getLocalStorageItem(key, fallback) {
     let item = localStorage.getItem(key)
@@ -631,7 +637,7 @@ async function healthCheck() {
 }

 // makes a single image. don't call this directly, use makeImage() instead
-async function doMakeImage(reqBody) {
+async function doMakeImage(reqBody, batchCount) {
     if (taskStopped) {
         return
     }
@@ -648,6 +654,41 @@ async function doMakeImage(reqBody) {
             body: JSON.stringify(reqBody)
         })

+        let reader = res.body.getReader()
+        let textDecoder = new TextDecoder()
+        let finalJSON = ''
+        while (true) {
+            try {
+                const {value, done} = await reader.read()
+                if (done) {
+                    break
+                }
+
+                let jsonStr = textDecoder.decode(value)
+
+                try {
+                    let stepUpdate = JSON.parse(jsonStr)
+
+                    if (stepUpdate.step !== undefined) {
+                        let batchSize = parseInt(reqBody['num_inference_steps'])
+                        let overallStepCount = stepUpdate.step + batchesDone * batchSize
+                        let totalSteps = batchCount * batchSize
+                        let percent = 100 * (overallStepCount / totalSteps)
+                        percent = percent.toFixed(0)
+                        outputMsg.innerHTML = `Batch ${batchesDone+1} of ${batchCount}`
+                        progressBar.innerHTML = `Generating image: ${percent}%`
+                        progressBar.style.display = 'block'
+                    }
+                } catch (e) {
+                    finalJSON += jsonStr
+                }
+            } catch (e) {
+                logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
+                res = undefined
+                throw e
+            }
+        }
+
         if (res.status != 200) {
             if (serverStatus === 'online') {
                 logError('Stable Diffusion had an error: ' + await res.text() + '. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
@@ -655,8 +696,10 @@ async function doMakeImage(reqBody) {
                 logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed.", res)
             }
             res = undefined
+            progressBar.style.display = 'none'
         } else {
-            res = await res.json()
+            res = JSON.parse(finalJSON)
+            progressBar.style.display = 'none'

             if (res.status !== 'succeeded') {
                 let msg = ''
@@ -680,7 +723,9 @@ async function doMakeImage(reqBody) {
         }
     } catch (e) {
         console.log('request error', e)
+        logError('Stable Diffusion had an error. Please check the logs in the command-line window. This happens sometimes. Maybe modify the prompt or seed a little bit?', res)
         setStatus('request', 'error', 'error')
+        progressBar.style.display = 'none'
     }

     if (!res) {
@@ -789,7 +834,7 @@ async function makeImage() {

     let validation = validateInput()
     if (validation['isValid']) {
-        outputMsg.innerHTML = 'Fetching..'
+        outputMsg.innerHTML = 'Starting..'
     } else {
         if (validation['error']) {
             logError(validation['error'])
@@ -806,6 +851,7 @@ async function makeImage() {
     stopImageBtn.style.display = 'block'

     taskStopped = false
+    batchesDone = 0

     let seed = (randomSeedField.checked ? Math.floor(Math.random() * 10000000) : parseInt(seedField.value))
     let numOutputsTotal = parseInt(numOutputsTotalField.value)
@@ -831,7 +877,8 @@ async function makeImage() {
         // allow_nsfw: allowNSFWField.checked,
         turbo: turboField.checked,
         use_cpu: useCPUField.checked,
-        use_full_precision: useFullPrecisionField.checked
+        use_full_precision: useFullPrecisionField.checked,
+        stream_progress_updates: true
     }

     if (IMAGE_REGEX.test(initImagePreview.src)) {
@@ -867,7 +914,8 @@ async function makeImage() {
     for (let i = 0; i < batchCount; i++) {
         reqBody['seed'] = seed + (i * batchSize)

-        let success = await doMakeImage(reqBody)
+        let success = await doMakeImage(reqBody, batchCount)
+        batchesDone++

         if (success) {
             outputMsg.innerHTML = 'Processed batch ' + (i+1) + '/' + batchCount
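For reference, the same stream can be consumed outside the browser. A minimal Python sketch, not part of this commit: the host/port and the /image route are assumptions, and it applies the same percentage arithmetic as the JavaScript above, treating any chunk that does not parse as JSON as part of the final response.

import json
import requests   # assumption: any HTTP client that can stream a response body works here

def render_with_progress(req_body, batch_count=1, batches_done=0,
                         url='http://localhost:9000/image'):   # host and route are assumed
    steps_per_batch = int(req_body['num_inference_steps'])
    final_json = ''
    with requests.post(url, json=req_body, stream=True) as res:
        for chunk in res.iter_content(chunk_size=None):
            text = chunk.decode('utf-8')
            try:
                update = json.loads(text)
            except ValueError:
                final_json += text    # not a complete JSON object: part of the final response
                continue
            if 'step' in update:      # progress update: same math as the JS above
                overall = update['step'] + batches_done * steps_per_batch
                total = batch_count * steps_per_batch
                print(f'Generating image: {100 * overall / total:.0f}%')
            else:                     # a complete JSON chunk without "step" is the final response
                final_json = text
    return json.loads(final_json)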
@@ -21,6 +21,8 @@ class Request:
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     show_only_filtered_image: bool = False

+    stream_progress_updates: bool = False
+
     def json(self):
         return {
             "prompt": self.prompt,
@@ -50,7 +52,9 @@ class Request:
         use_full_precision: {self.use_full_precision}
         use_face_correction: {self.use_face_correction}
         use_upscale: {self.use_upscale}
-        show_only_filtered_image: {self.show_only_filtered_image}'''
+        show_only_filtered_image: {self.show_only_filtered_image}
+
+        stream_progress_updates: {self.stream_progress_updates}'''

 class Image:
     data: str # base64
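Seen from the rest of the code, the new field is just another option on the internal request object; ui/server.py (further down) copies it over from the incoming HTTP request. A compressed sketch, with every other field elided:

class Request:
    # ...all other fields elided; see the hunk above for the full class...
    show_only_filtered_image: bool = False
    stream_progress_updates: bool = False   # new in this commit

r = Request()
r.stream_progress_updates = True   # what ui/server.py now does with req.stream_progress_updates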
@@ -1,34 +1,121 @@
 diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
-index dcf7901..1f99adc 100644
+index dcf7901..4028a70 100644
 --- a/optimizedSD/ddpm.py
 +++ b/optimizedSD/ddpm.py
-@@ -528,7 +528,8 @@ class UNet(DDPM):
+@@ -485,6 +485,7 @@ class UNet(DDPM):
+                  log_every_t=100,
+                  unconditional_guidance_scale=1.,
+                  unconditional_conditioning=None,
++                 streaming_callbacks = False,
+                  ):
+
+
+@@ -523,12 +524,15 @@ class UNet(DDPM):
+                                      log_every_t=log_every_t,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
++                                     streaming_callbacks=streaming_callbacks
+                                      )
+
         elif sampler == "ddim":
             samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
                                         unconditional_conditioning=unconditional_conditioning,
 -                                        mask = mask,init_latent=x_T,use_original_steps=False)
 +                                        mask = mask,init_latent=x_T,use_original_steps=False,
-+                                        callback=callback, img_callback=img_callback)
++                                        callback=callback, img_callback=img_callback,
++                                        streaming_callbacks=streaming_callbacks)

         # elif sampler == "euler":
         #     cvd = CompVisDenoiser(self.alphas_cumprod)
-@@ -687,7 +688,8 @@ class UNet(DDPM):
+@@ -536,11 +540,15 @@ class UNet(DDPM):
+         #     samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning,
+         #                                 unconditional_guidance_scale=unconditional_guidance_scale)
+
++        if streaming_callbacks: # this line needs to be right after the sampling() call
++            yield from samples
++
+         if(self.turbo):
+             self.model1.to("cpu")
+             self.model2.to("cpu")
+
+-        return samples
++        if not streaming_callbacks:
++            return samples
+
+     @torch.no_grad()
+     def plms_sampling(self, cond,b, img,
+@@ -548,7 +556,8 @@ class UNet(DDPM):
+                       callback=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, log_every_t=100,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
++                      unconditional_guidance_scale=1., unconditional_conditioning=None,
++                      streaming_callbacks=False):
+
+         device = self.betas.device
+         timesteps = self.ddim_timesteps
+@@ -580,10 +589,22 @@ class UNet(DDPM):
+             old_eps.append(e_t)
+             if len(old_eps) >= 4:
+                 old_eps.pop(0)
+-            if callback: callback(i)
+-            if img_callback: img_callback(pred_x0, i)
+
+-        return img
++            if callback:
++                if streaming_callbacks:
++                    yield from callback(i)
++                else:
++                    callback(i)
++            if img_callback:
++                if streaming_callbacks:
++                    yield from img_callback(pred_x0, i)
++                else:
++                    img_callback(pred_x0, i)
++
++        if streaming_callbacks and img_callback:
++            yield from img_callback(img, len(iterator)-1)
++        else:
++            return img
+
+     @torch.no_grad()
+     def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+@@ -687,7 +708,9 @@ class UNet(DDPM):

     @torch.no_grad()
     def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
 -                      mask = None,init_latent=None,use_original_steps=False):
 +                      mask = None,init_latent=None,use_original_steps=False,
-+                      callback=None, img_callback=None):
++                      callback=None, img_callback=None,
++                      streaming_callbacks=False):

         timesteps = self.ddim_timesteps
         timesteps = timesteps[:t_start]
-@@ -710,6 +712,9 @@ class UNet(DDPM):
+@@ -710,11 +733,25 @@ class UNet(DDPM):
             x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                        unconditional_guidance_scale=unconditional_guidance_scale,
                                        unconditional_conditioning=unconditional_conditioning)
 +
-+            if callback: callback(i)
-+            if img_callback: img_callback(x_dec, i)
++            if callback:
++                if streaming_callbacks:
++                    yield from callback(i)
++                else:
++                    callback(i)
++            if img_callback:
++                if streaming_callbacks:
++                    yield from img_callback(x_dec, i)
++                else:
++                    img_callback(x_dec, i)

             if mask is not None:
-                 return x0 * mask + (1. - mask) * x_dec
+-                return x0 * mask + (1. - mask) * x_dec
++                x_dec = x0 * mask + (1. - mask) * x_dec
+
+-        return x_dec
++        if streaming_callbacks and img_callback:
++            yield from img_callback(x_dec, len(iterator)-1)
++        else:
++            return x_dec
+
+
+     @torch.no_grad()
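One detail worth noting when reading the patch above: adding `yield from` to sample(), plms_sampling() and ddim_sampling() makes them generators even when streaming_callbacks is False, so the plain `return samples` path only surfaces through StopIteration.value. A small self-contained sketch of that behaviour (illustrative names, not the repository's code):

def sample(steps, img_callback=None, streaming_callbacks=False):
    samples = None
    for i in range(steps):
        samples = f'x@{i}'                            # placeholder for one denoising step
        if img_callback:
            if streaming_callbacks:
                yield from img_callback(samples, i)   # forward the callback's progress payloads
            else:
                img_callback(samples, i)              # plain callback, nothing is yielded
    if not streaming_callbacks:
        return samples                                # only reachable via StopIteration.value

def progress(samples, i):
    # generator-style callback, like the patched img_callback feeding mk_img() below
    yield {'step': i}

# streaming: iterate the generator and get one payload per step
for update in sample(3, img_callback=progress, streaming_callbacks=True):
    print(update)

# non-streaming: the function is still a generator, so the "returned" value
# has to be recovered from StopIteration (or via `yield from` in a wrapper)
gen = sample(3, img_callback=print, streaming_callbacks=False)
try:
    while True:
        next(gen)
except StopIteration as stop:
    print('final samples:', stop.value)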
@@ -1,3 +1,4 @@
+import json
 import os, re
 import traceback
 import torch
@@ -332,15 +333,23 @@ def mk_img(req: Request):

         partial_x_samples = x_samples

+        if req.stream_progress_updates:
+            yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
+
         if stop_processing:
             raise UserInitiatedStop("User requested that we stop processing")

     # run the handler
     try:
         if handler == _txt2img:
-            x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
+            x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
         else:
-            x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
+            x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
+
+        if req.stream_progress_updates:
+            yield from x_samples
+
+            x_samples = partial_x_samples
     except UserInitiatedStop:
         if partial_x_samples is None:
             continue
@@ -421,6 +430,9 @@ def mk_img(req: Request):
     del x_samples
     print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

+    if req.stream_progress_updates:
+        yield json.dumps(res.json())
+    else:
         return res

 def save_image(img, img_out_path):
@@ -438,7 +450,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
     except:
         print('could not save the file', traceback.format_exc())

-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
+def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

     if device != "cpu":
@@ -458,12 +470,16 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
         eta=opt_ddim_eta,
         x_T=start_code,
         img_callback=img_callback,
+        streaming_callbacks=streaming_callbacks,
         sampler = 'plms',
     )

+    if streaming_callbacks:
+        yield from samples_ddim
+    else:
         return samples_ddim

-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
+def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
     # encode (scaled latent)
     z_enc = model.stochastic_encode(
         init_latent,
@@ -480,9 +496,13 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
         unconditional_guidance_scale=opt_scale,
         unconditional_conditioning=uc,
         img_callback=img_callback,
+        streaming_callbacks=streaming_callbacks,
         sampler = 'ddim'
     )

+    if streaming_callbacks:
+        yield from samples_ddim
+    else:
         return samples_ddim

 def gc():
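The control flow that the mk_img() hunks set up can be summarised in a few lines: the sampler is a generator that only yields progress strings, and the actual samples come back through a side channel that img_callback keeps updated (partial_x_samples in the real code). A minimal sketch with illustrative names:

import json

def make_image(total_steps):
    partial_samples = None                 # side channel, like partial_x_samples above

    def img_callback(samples, i):
        nonlocal partial_samples
        partial_samples = samples          # keep the latest result around
        yield json.dumps({'step': i, 'total_steps': total_steps})

    def sampler(steps, callback):          # stands in for _txt2img / _img2img
        samples = None
        for i in range(steps):
            samples = f'samples@step{i}'   # placeholder for the denoised latents
            yield from callback(samples, i)

    yield from sampler(total_steps, img_callback)   # forward every progress chunk to the caller
    samples = partial_samples                       # then recover the data from the side channel
    yield json.dumps({'status': 'succeeded', 'samples': samples})

for chunk in make_image(3):
    print(chunk)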
ui/server.py
@@ -15,7 +15,7 @@ CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
 OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder

 from fastapi import FastAPI, HTTPException
-from starlette.responses import FileResponse
+from starlette.responses import FileResponse, StreamingResponse
 from pydantic import BaseModel
 import logging

@@ -50,6 +50,8 @@ class ImageRequest(BaseModel):
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     show_only_filtered_image: bool = False

+    stream_progress_updates: bool = False
+
 class SetAppConfigRequest(BaseModel):
     update_branch: str = "main"

@@ -106,9 +108,14 @@ def image(req : ImageRequest):
     r.use_face_correction = req.use_face_correction
     r.show_only_filtered_image = req.show_only_filtered_image

-    try:
-        res: Response = runtime.mk_img(r)
+    r.stream_progress_updates = req.stream_progress_updates
+
+    try:
+        res = runtime.mk_img(r)

-        return res.json()
+        if r.stream_progress_updates:
+            return StreamingResponse(res, media_type='application/json')
+        else:
+            return res.json()
     except Exception as e:
         print(traceback.format_exc())
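On the server side the contract is small: when the request asks for streaming, hand the generator returned by the runtime to StreamingResponse and each yielded string becomes one chunk of the HTTP response. A minimal FastAPI sketch, with an illustrative route and a stand-in generator rather than the project's real mk_img():

import json
import time

from fastapi import FastAPI
from fastapi.responses import StreamingResponse   # also importable from starlette.responses

app = FastAPI()

def fake_render(total_steps: int):
    # stand-in for runtime.mk_img(): yield progress objects, then the final result
    for i in range(total_steps):
        time.sleep(0.1)               # pretend to run one denoising step
        yield json.dumps({'step': i, 'total_steps': total_steps})
    yield json.dumps({'status': 'succeeded'})

@app.post('/demo-image')              # illustrative route, not the UI's real one
def demo_image(total_steps: int = 25, stream_progress_updates: bool = True):
    res = fake_render(total_steps)
    if stream_progress_updates:
        return StreamingResponse(res, media_type='application/json')
    return json.loads(list(res)[-1])  # non-streaming: drain the generator, return the last chunk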