Merge pull request #346 from cmdr2/task_manager

Allow multiple tabs and computers to generate tasks without throwing errors (by @madrang)
This commit is contained in:
cmdr2 2022-10-17 17:57:10 +05:30 committed by GitHub
commit 1c171d0f12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 662 additions and 298 deletions

View File

@ -16,7 +16,7 @@
<div id="container"> <div id="container">
<div id="top-nav"> <div id="top-nav">
<div id="logo"> <div id="logo">
<h1>Stable Diffusion UI <small>v2.26 <span id="updateBranchLabel"></span></small></h1> <h1>Stable Diffusion UI <small>v2.27 <span id="updateBranchLabel"></span></small></h1>
</div> </div>
<ul id="top-nav-items"> <ul id="top-nav-items">
<li class="dropdown"> <li class="dropdown">

View File

@ -326,13 +326,13 @@ img {
height: 8pt; height: 8pt;
border-radius: 4pt; */ border-radius: 4pt; */
font-size: 14pt; font-size: 14pt;
color: rgb(128, 87, 0); color: rgb(200, 139, 0);
/* background-color: rgb(197, 1, 1); */ /* background-color: rgb(197, 1, 1); */
/* transform: translateY(15%); */ /* transform: translateY(15%); */
display: inline; display: inline;
} }
#server-status-msg { #server-status-msg {
color: rgb(128, 87, 0); color: rgb(200, 139, 0);
padding-left: 2pt; padding-left: 2pt;
font-size: 10pt; font-size: 10pt;
} }
@ -479,7 +479,12 @@ img {
.activeTaskLabel { .activeTaskLabel {
background:rgb(0, 90, 30); background:rgb(0, 90, 30);
border: 1px solid rgb(0, 75, 19); border: 1px solid rgb(0, 75, 19);
color:rgb(204, 255, 217) color:rgb(222, 253, 230)
}
.waitingTaskLabel {
background:rgb(128, 89, 0);
border: 1px solid rgb(0, 75, 19);
color:rgb(255, 242, 211)
} }
.secondaryButton { .secondaryButton {
background: rgb(132, 8, 0); background: rgb(132, 8, 0);

View File

@ -20,7 +20,7 @@ const INPAINTING_EDITOR_SIZE = 450
const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64') const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64')
let sessionId = new Date().getTime() let sessionId = Date.now()
let promptField = document.querySelector('#prompt') let promptField = document.querySelector('#prompt')
let promptsFromFileSelector = document.querySelector('#prompt_from_file') let promptsFromFileSelector = document.querySelector('#prompt_from_file')
@ -122,7 +122,7 @@ maskResetButton.innerHTML = 'Clear'
maskResetButton.style.fontWeight = 'normal' maskResetButton.style.fontWeight = 'normal'
maskResetButton.style.fontSize = '10pt' maskResetButton.style.fontSize = '10pt'
let serverStatus = 'offline' let serverState = {'status': 'Offline', 'time': Date.now()}
let activeTags = [] let activeTags = []
let modifiers = [] let modifiers = []
let lastPromptUsed = '' let lastPromptUsed = ''
@ -225,21 +225,38 @@ function getOutputFormat() {
} }
function setStatus(statusType, msg, msgType) { function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') { }
return
}
if (msgType == 'error') { function setServerStatus(msgType, msg) {
// msg = '<span style="color: red">' + msg + '<span>' switch(msgType) {
case 'online':
serverStatusColor.style.color = 'green'
serverStatusMsg.style.color = 'green'
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
break
case 'busy':
serverStatusColor.style.color = 'rgb(200, 139, 0)'
serverStatusMsg.style.color = 'rgb(200, 139, 0)'
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
break
case 'error':
serverStatusColor.style.color = 'red' serverStatusColor.style.color = 'red'
serverStatusMsg.style.color = 'red' serverStatusMsg.style.color = 'red'
serverStatusMsg.innerText = 'Stable Diffusion has stopped' serverStatusMsg.innerText = 'Stable Diffusion has stopped'
} else if (msgType == 'success') { break
// msg = '<span style="color: green">' + msg + '<span>' }
serverStatusColor.style.color = 'green' }
serverStatusMsg.style.color = 'green' function isServerAvailable() {
serverStatusMsg.innerText = 'Stable Diffusion is ready' if (typeof serverState !== 'object') {
serverStatus = 'online' return false
}
switch (serverState.status) {
case 'LoadingModel':
case 'Rendering':
case 'Online':
return true
default:
return false
} }
} }
@ -263,6 +280,11 @@ function logError(msg, res, outputMsg) {
console.log('request error', res) console.log('request error', res)
setStatus('request', 'error', 'error') setStatus('request', 'error', 'error')
} }
function asyncDelay(timeout) {
return new Promise(function(resolve, reject) {
setTimeout(resolve, timeout, true)
})
}
function playSound() { function playSound() {
const audio = new Audio('/media/ding.mp3') const audio = new Audio('/media/ding.mp3')
@ -277,16 +299,40 @@ function playSound() {
async function healthCheck() { async function healthCheck() {
try { try {
let res = await fetch('/ping') let res = undefined
res = await res.json() if (sessionId) {
res = await fetch('/ping?session_id=' + sessionId)
if (res[0] == 'OK') {
setStatus('server', 'online', 'success')
} else { } else {
setStatus('server', 'offline', 'error') res = await fetch('/ping')
} }
serverState = await res.json()
if (typeof serverState !== 'object' || typeof serverState.status !== 'string') {
serverState = {'status': 'Offline', 'time': Date.now()}
setServerStatus('error', 'offline')
return
}
// Set status
switch(serverState.status) {
case 'Init':
// Wait for init to complete before updating status.
break
case 'Online':
setServerStatus('online', 'ready')
break
case 'LoadingModel':
setServerStatus('busy', 'loading model')
break
case 'Rendering':
setServerStatus('busy', 'rendering')
break
default: // Unavailable
setServerStatus('error', serverState.status.toLowerCase())
break
}
serverState.time = Date.now()
} catch (e) { } catch (e) {
setStatus('server', 'offline', 'error') serverState = {'status': 'Offline', 'time': Date.now()}
setServerStatus('error', 'offline')
} }
} }
function resizeInpaintingEditor() { function resizeInpaintingEditor() {
@ -329,7 +375,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
if(typeof res != 'object') return if(typeof res != 'object') return
res.output.reverse() res.output.reverse()
res.output.forEach((result, index) => { res.output.forEach((result, index) => {
const imageData = result?.data || result?.path + '?t=' + new Date().getTime(), const imageData = result?.data || result?.path + '?t=' + Date.now(),
imageSeed = result?.seed, imageSeed = result?.seed,
imagePrompt = reqBody.prompt, imagePrompt = reqBody.prompt,
imageInferenceSteps = reqBody.num_inference_steps, imageInferenceSteps = reqBody.num_inference_steps,
@ -440,8 +486,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) {
} }
function getStartNewTaskHandler(reqBody, imageItemElem, mode) { function getStartNewTaskHandler(reqBody, imageItemElem, mode) {
return function() { return function() {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
alert('The server is still starting up..') alert('The server is not available.')
return return
} }
const imageElem = imageItemElem.querySelector('img') const imageElem = imageItemElem.querySelector('img')
@ -507,37 +553,76 @@ async function doMakeImage(task) {
const progressBar = task['progressBar'] const progressBar = task['progressBar']
let res = undefined let res = undefined
let stepUpdate = undefined
try { try {
res = await fetch('/image', { const lastTask = serverState.task
let renderRequest = undefined
do {
res = await fetch('/render', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify(reqBody) body: JSON.stringify(reqBody)
}) })
renderRequest = await res.json()
// status_code 503, already a task running.
} while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000))
if (typeof renderRequest?.stream !== 'string') {
console.log('Endpoint response: ', renderRequest)
throw new Error('Endpoint response does not contains a response stream url.')
}
task['taskStatusLabel'].innerText = "Waiting"
task['taskStatusLabel'].classList.add('waitingTaskLabel')
task['taskStatusLabel'].classList.remove('activeTaskLabel')
do { // Wait for server status to update.
await asyncDelay(250)
if (!isServerAvailable()) {
throw new Error('Connexion with server lost.')
}
} while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task)
if (serverState.session !== 'pending' && serverState.session !== 'running' && serverState.session !== 'buffer') {
if (serverState.session === 'stopped') {
return false
}
throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
}
while (serverState.task === renderRequest.task && serverState.session === 'pending') {
// Wait for task to start on server.
await asyncDelay(1500)
}
// Task started!
res = await fetch(renderRequest.stream, {
headers: {
'Content-Type': 'application/json'
},
})
task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].classList.add('activeTaskLabel')
task['taskStatusLabel'].classList.remove('waitingTaskLabel')
let stepUpdate = undefined
let reader = res.body.getReader() let reader = res.body.getReader()
let textDecoder = new TextDecoder() let textDecoder = new TextDecoder()
let finalJSON = '' let finalJSON = ''
let prevTime = -1 let prevTime = -1
let readComplete = false let readComplete = false
while (true) { while (!readComplete || finalJSON.length > 0) {
let t = new Date().getTime() let t = Date.now()
let jsonStr = '' let jsonStr = ''
if (!readComplete) { if (!readComplete) {
const {value, done} = await reader.read() const {value, done} = await reader.read()
if (done) { if (done) {
readComplete = true readComplete = true
} }
if (done && finalJSON.length <= 0 && !value) {
break
}
if (value) { if (value) {
jsonStr = textDecoder.decode(value) jsonStr = textDecoder.decode(value)
} }
} }
stepUpdate = undefined
try { try {
// hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot. // hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot.
// this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste... // this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste...
@ -571,9 +656,6 @@ async function doMakeImage(task) {
throw e throw e
} }
} }
if (readComplete && finalJSON.length <= 0) {
break
}
if (typeof stepUpdate === 'object' && 'step' in stepUpdate) { if (typeof stepUpdate === 'object' && 'step' in stepUpdate) {
let batchSize = stepUpdate.total_steps let batchSize = stepUpdate.total_steps
let overallStepCount = stepUpdate.step + task.batchesDone * batchSize let overallStepCount = stepUpdate.step + task.batchesDone * batchSize
@ -598,6 +680,23 @@ async function doMakeImage(task) {
showImages(reqBody, stepUpdate, outputContainer, true) showImages(reqBody, stepUpdate, outputContainer, true)
} }
} }
if (stepUpdate?.status) {
break
}
if (readComplete && finalJSON.length <= 0) {
if (res.status === 200) {
await asyncDelay(5000)
res = await fetch(renderRequest.stream, {
headers: {
'Content-Type': 'application/json'
},
})
reader = res.body.getReader()
readComplete = false
} else {
console.log('Stream stopped: ', res)
}
}
prevTime = t prevTime = t
} }
@ -614,27 +713,28 @@ async function doMakeImage(task) {
3. Try generating a smaller image.<br/>` 3. Try generating a smaller image.<br/>`
} }
} else { } else {
msg = `Unexpected Read Error:<br/><pre>StepUpdate:${JSON.stringify(stepUpdate, undefined, 4)}</pre>` msg = `Unexpected Read Error:<br/><pre>StepUpdate: ${JSON.stringify(stepUpdate, undefined, 4)}</pre>`
} }
logError(msg, res, outputMsg) logError(msg, res, outputMsg)
return false return false
} }
if (typeof stepUpdate !== 'object' || !res || res.status != 200) { if (typeof stepUpdate !== 'object' || !res || res.status != 200) {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg) logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg)
} else if (typeof res === 'object') { } else if (typeof res === 'object') {
let msg = 'Stable Diffusion had an error reading the response: ' let msg = 'Stable Diffusion had an error reading the response: '
try { // 'Response': body stream already read try { // 'Response': body stream already read
msg += 'Read: ' + await res.text() msg += 'Read: ' + await res.text()
} catch(e) { } catch(e) {
msg += 'No error response. ' msg += 'Unexpected end of stream. '
} }
if (finalJSON) { if (finalJSON) {
msg += 'Buffered data: ' + finalJSON msg += 'Buffered data: ' + finalJSON
} }
logError(msg, res, outputMsg) logError(msg, res, outputMsg)
} else { } else {
msg = `Unexpected Read Error:<br/><pre>Response:${res}<br/>StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>` let msg = `Unexpected Read Error:<br/><pre>Response: ${res}<br/>StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
logError(msg, res, outputMsg)
} }
progressBar.style.display = 'none' progressBar.style.display = 'none'
return false return false
@ -682,14 +782,14 @@ async function checkTasks() {
let task = taskQueue.pop() let task = taskQueue.pop()
currentTask = task currentTask = task
let time = new Date().getTime() let time = Date.now()
let successCount = 0 let successCount = 0
task.isProcessing = true task.isProcessing = true
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop' task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
task['taskStatusLabel'].innerText = "Processing" task['taskStatusLabel'].innerText = "Starting"
task['taskStatusLabel'].className += " activeTaskLabel" task['taskStatusLabel'].classList.add('waitingTaskLabel')
const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1)) const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1))
const startSeed = task.reqBody.seed || task.seed const startSeed = task.reqBody.seed || task.seed
@ -724,7 +824,7 @@ async function checkTasks() {
task['stopTask'].innerHTML = '<i class="fa-solid fa-trash-can"></i> Remove' task['stopTask'].innerHTML = '<i class="fa-solid fa-trash-can"></i> Remove'
task['taskStatusLabel'].style.display = 'none' task['taskStatusLabel'].style.display = 'none'
time = new Date().getTime() - time time = Date.now() - time
time /= 1000 time /= 1000
if (successCount === task.batchCount) { if (successCount === task.batchCount) {
@ -814,8 +914,8 @@ function getCurrentUserRequest() {
} }
function makeImage() { function makeImage() {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
alert('The server is still starting up..') alert('The server is not available.')
return return
} }
const taskTemplate = getCurrentUserRequest() const taskTemplate = getCurrentUserRequest()
@ -868,7 +968,7 @@ function createTask(task) {
if (task['isProcessing']) { if (task['isProcessing']) {
task.isProcessing = false task.isProcessing = false
try { try {
let res = await fetch('/image/stop') let res = await fetch('/image/stop?session_id=' + sessionId)
} catch (e) { } catch (e) {
console.log(e) console.log(e)
} }
@ -1008,7 +1108,7 @@ async function stopAllTasks() {
} }
try { try {
let res = await fetch('/image/stop') let res = await fetch('/image/stop?session_id=' + sessionId)
} catch (e) { } catch (e) {
console.log(e) console.log(e)
} }
@ -1135,9 +1235,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength() updatePromptStrength()
useBetaChannelField.addEventListener('click', async function(e) { useBetaChannelField.addEventListener('click', async function(e) {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
// logError('The server is still starting up..') // logError('The server is still starting up..')
alert('The server is still starting up..') alert('The server is not available.')
e.preventDefault() e.preventDefault()
return false return false
} }
@ -1164,7 +1264,7 @@ useBetaChannelField.addEventListener('click', async function(e) {
async function getAppConfig() { async function getAppConfig() {
try { try {
let res = await fetch('/app_config') let res = await fetch('/get/app_config')
const config = await res.json() const config = await res.json()
if (config.update_branch === 'beta') { if (config.update_branch === 'beta') {
@ -1180,7 +1280,7 @@ async function getAppConfig() {
async function getModels() { async function getModels() {
try { try {
let res = await fetch('/models') let res = await fetch('/get/models')
const models = await res.json() const models = await res.json()
let activeModel = models['active'] let activeModel = models['active']
@ -1451,10 +1551,10 @@ async function getDiskPath() {
return return
} }
let res = await fetch('/output_dir') let res = await fetch('/get/output_dir')
if (res.status === 200) { if (res.status === 200) {
res = await res.json() res = await res.json()
res = res[0] res = res.output_dir
document.querySelector('#diskPath').value = res document.querySelector('#diskPath').value = res
} }
@ -1562,14 +1662,15 @@ function resizeModifierCards(val) {
const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix)) const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix))
card.className = classes.join(' ').trim() card.className = classes.join(' ').trim()
if(val != 0) if(val != 0) {
card.classList.add(cardSize(val)) card.classList.add(cardSize(val))
}
}) })
} }
async function loadModifiers() { async function loadModifiers() {
try { try {
let res = await fetch('/modifiers.json?v=2') let res = await fetch('/get/modifiers')
if (res.status === 200) { if (res.status === 200) {
res = await res.json() res = await res.json()

View File

@ -197,6 +197,35 @@ def load_model_real_esrgan(real_esrgan_to_use):
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision) print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
if disk_path is None: return None
if session_id is None: return None
if ext is None: raise Exception('Missing ext')
session_out_path = os.path.join(disk_path, session_id)
os.makedirs(session_out_path, exist_ok=True)
prompt_flattened = filename_regex.sub('_', prompt)[:50]
img_id = str(uuid.uuid4())[-8:]
if suffix is not None:
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
def apply_filters(filter_name, image_data):
print(f'Applying filter {filter_name}...')
gc()
if filter_name == 'gfpgan':
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan':
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
image_data = output[:,:,::-1]
return image_data
def mk_img(req: Request): def mk_img(req: Request):
try: try:
yield from do_mk_img(req) yield from do_mk_img(req)
@ -283,23 +312,11 @@ def do_mk_img(req: Request):
opt_prompt = req.prompt opt_prompt = req.prompt
opt_seed = req.seed opt_seed = req.seed
opt_n_samples = req.num_outputs
opt_n_iter = 1 opt_n_iter = 1
opt_scale = req.guidance_scale
opt_C = 4 opt_C = 4
opt_H = req.height
opt_W = req.width
opt_f = 8 opt_f = 8
opt_ddim_steps = req.num_inference_steps
opt_ddim_eta = 0.0 opt_ddim_eta = 0.0
opt_strength = req.prompt_strength
opt_save_to_disk_path = req.save_to_disk_path
opt_init_img = req.init_image opt_init_img = req.init_image
opt_use_face_correction = req.use_face_correction
opt_use_upscale = req.use_upscale
opt_show_only_filtered = req.show_only_filtered_image
opt_format = req.output_format
opt_sampler_name = req.sampler
print(req.to_string(), '\n device', device) print(req.to_string(), '\n device', device)
@ -307,7 +324,7 @@ def do_mk_img(req: Request):
seed_everything(opt_seed) seed_everything(opt_seed)
batch_size = opt_n_samples batch_size = req.num_outputs
prompt = opt_prompt prompt = opt_prompt
assert prompt is not None assert prompt is not None
data = [batch_size * [prompt]] data = [batch_size * [prompt]]
@ -327,7 +344,7 @@ def do_mk_img(req: Request):
else: else:
handler = _img2img handler = _img2img
init_image = load_img(req.init_image, opt_W, opt_H) init_image = load_img(req.init_image, req.width, req.height)
init_image = init_image.to(device) init_image = init_image.to(device)
if device != "cpu" and precision == "autocast": if device != "cpu" and precision == "autocast":
@ -339,7 +356,7 @@ def do_mk_img(req: Request):
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None: if req.mask is not None:
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device) mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size) mask = repeat(mask, '1 ... -> b ...', b=batch_size)
@ -348,12 +365,12 @@ def do_mk_img(req: Request):
move_fs_to_cpu() move_fs_to_cpu()
assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(opt_strength * opt_ddim_steps) t_enc = int(req.prompt_strength * req.num_inference_steps)
print(f"target t_enc is {t_enc} steps") print(f"target t_enc is {t_enc} steps")
if opt_save_to_disk_path is not None: if req.save_to_disk_path is not None:
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id) session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
os.makedirs(session_out_path, exist_ok=True) os.makedirs(session_out_path, exist_ok=True)
else: else:
session_out_path = None session_out_path = None
@ -366,7 +383,7 @@ def do_mk_img(req: Request):
with precision_scope("cuda"): with precision_scope("cuda"):
modelCS.to(device) modelCS.to(device)
uc = None uc = None
if opt_scale != 1.0: if req.guidance_scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
@ -393,7 +410,7 @@ def do_mk_img(req: Request):
partial_x_samples = x_samples partial_x_samples = x_samples
if req.stream_progress_updates: if req.stream_progress_updates:
n_steps = opt_ddim_steps if req.init_image is None else t_enc n_steps = req.num_inference_steps if req.init_image is None else t_enc
progress = {"step": i, "total_steps": n_steps} progress = {"step": i, "total_steps": n_steps}
if req.stream_image_progress and i % 5 == 0: if req.stream_image_progress and i % 5 == 0:
@ -425,9 +442,9 @@ def do_mk_img(req: Request):
# run the handler # run the handler
try: try:
if handler == _txt2img: if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name) x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else: else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask) x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
yield from x_samples yield from x_samples
@ -447,68 +464,48 @@ def do_mk_img(req: Request):
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample) img = Image.fromarray(x_sample)
has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \ has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')) (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
return_orig_img = not has_filters or not opt_show_only_filtered return_orig_img = not has_filters or not req.show_only_filtered_image
if stop_processing: if stop_processing:
return_orig_img = True return_orig_img = True
if opt_save_to_disk_path is not None: if req.save_to_disk_path is not None:
prompt_flattened = filename_regex.sub('_', prompts[0])
prompt_flattened = prompt_flattened[:50]
img_id = str(uuid.uuid4())[-8:]
file_path = f"{prompt_flattened}_{img_id}"
img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
if return_orig_img: if return_orig_img:
img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format)
save_image(img, img_out_path) save_image(img, img_out_path)
meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], 'txt')
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file) save_metadata(meta_out_path, req, prompts[0], opt_seed)
if return_orig_img: if return_orig_img:
img_data = img_to_base64_str(img, opt_format) img_data = img_to_base64_str(img, req.output_format)
res_image_orig = ResponseImage(data=img_data, seed=opt_seed) res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
res.images.append(res_image_orig) res.images.append(res_image_orig)
if opt_save_to_disk_path is not None: if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path res_image_orig.path_abs = img_out_path
del img del img
if has_filters and not stop_processing: if has_filters and not stop_processing:
print('Applying filters..')
gc()
filters_applied = [] filters_applied = []
if req.use_face_correction:
if opt_use_face_correction: x_sample = apply_filters('gfpgan', x_sample)
_, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) filters_applied.append(req.use_face_correction)
x_sample = output[:,:,::-1] if req.use_upscale:
filters_applied.append(opt_use_face_correction) x_sample = apply_filters('real_esrgan', x_sample)
filters_applied.append(req.use_upscale)
if opt_use_upscale: if (len(filters_applied) > 0):
output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1])
x_sample = output[:,:,::-1]
filters_applied.append(opt_use_upscale)
filtered_image = Image.fromarray(x_sample) filtered_image = Image.fromarray(x_sample)
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
filtered_img_data = img_to_base64_str(filtered_image, opt_format) response_image = ResponseImage(data=filtered_img_data, seed=req.seed)
res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed) res.images.append(response_image)
res.images.append(res_image_filtered) if req.save_to_disk_path is not None:
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format, "_".join(filters_applied))
filters_applied = "_".join(filters_applied)
if opt_save_to_disk_path is not None:
filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
save_image(filtered_image, filtered_img_out_path) save_image(filtered_image, filtered_img_out_path)
res_image_filtered.path_abs = filtered_img_out_path response_image.path_abs = filtered_img_out_path
del filtered_image del filtered_image
seeds += str(opt_seed) + "," seeds += str(opt_seed) + ","
@ -529,9 +526,20 @@ def save_image(img, img_out_path):
except: except:
print('could not save the file', traceback.format_exc()) print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file): def save_metadata(meta_out_path, req, prompt, opt_seed):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}" metadata = f"""{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
Steps: {req.num_inference_steps}
Guidance Scale: {req.guidance_scale}
Prompt Strength: {req.prompt_strength}
Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
"""
try: try:
with open(meta_out_path, 'w') as f: with open(meta_out_path, 'w') as f:
f.write(metadata) f.write(metadata)

View File

@ -0,0 +1,298 @@
import json
import traceback
TASK_TTL = 15 * 60 # Discard last session's task timeout
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
from pydantic import BaseModel
from sd_internal import Request, Response
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__
class Symbol(metaclass=SymbolClass): pass
class ServerStates:
class Init(Symbol): pass
class LoadingModel(Symbol): pass
class Online(Symbol): pass
class Rendering(Symbol): pass
class Unavailable(Symbol): pass
class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
self.request: Request = req # Initial Request
self.response: Any = None # Copy of the last reponse
self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
async def read_buffer_generator(self):
try:
while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done()
yield res
except queue.Empty as e: yield
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
stream_progress_updates: bool = False
stream_image_progress: bool = False
# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
def __init__(self):
self._base = dict()
self._lock: threading.Lock = threading.RLock()
def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp
def clean(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
try:
# Create a list of expired keys to delete
to_delete = []
for key in self._base:
ttl, _ = self._base[key]
if self._is_expired(ttl):
to_delete.append(key)
# Remove Items
for key in to_delete:
del self._base[key]
print(f'Session {key} expired. Data removed.')
finally:
self._lock.release()
def clear(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
try: self._base.clear()
finally: self._lock.release()
def delete(self, key: Hashable) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
try:
if key not in self._base:
return False
del self._base[key]
return True
finally:
self._lock.release()
def keep(self, key: Hashable, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
try:
if key in self._base:
_, value = self._base.get(key)
self._base[key] = (self._get_ttl_time(ttl), value)
return True
return False
finally:
self._lock.release()
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
try:
self._base[key] = (
self._get_ttl_time(ttl), value
)
except Exception as e:
print(str(e))
print(traceback.format_exc())
return False
else:
return True
finally:
self._lock.release()
def tryGet(self, key: Hashable) -> Any:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
try:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
print(f'Session {key} expired. Discarding data.')
self.delete(key)
return None
return value
finally:
self._lock.release()
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
tasks_queue = queue.Queue()
task_cache = TaskCache()
default_model_to_load = None
def preload_model(file_path=None):
global current_state, current_state_error, current_model_path
if file_path == None:
file_path = default_model_to_load
if file_path == current_model_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.load_model_ckpt(ckpt_to_use=file_path)
current_model_path = file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
def thread_render():
global current_state, current_state_error, current_model_path
from . import runtime
current_state = ServerStates.Online
preload_model()
while True:
task_cache.clean()
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
task = None
try:
task = tasks_queue.get(timeout=1)
except queue.Empty as e:
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
else: continue
#if current_model_path != task.request.use_stable_diffusion_model:
# preload_model(task.request.use_stable_diffusion_model)
if current_state_error:
task.error = current_state_error
continue
print(f'Session {task.request.session_id} starting task {id(task)}')
try:
task.lock.acquire(blocking=False)
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
current_state = ServerStates.LoadingModel
except Exception as e:
task.error = e
task.lock.release()
tasks_queue.task_done()
print(traceback.format_exc())
continue
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
task_cache.keep(task.request.session_id, TASK_TTL)
# Task completed
task.lock.release()
tasks_queue.task_done()
task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
elif task.error is not None:
print(f'Session {task.request.session_id} task {id(task)} failed!')
else:
print(f'Session {task.request.session_id} task {id(task)} completed.')
current_state = ServerStates.Online
render_thread = threading.Thread(target=thread_render)
def start_render_thread():
# Start Rendering Thread
render_thread.daemon = True
render_thread.start()
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest):
if not render_thread.is_alive(): # Render thread is dead
raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache
task = task_cache.tryGet(req.session_id)
if task and not task.response and not task.error and not task.lock.locked():
# Unstarted task pending, deny queueing more than one.
raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')
#
from . import runtime
r = Request()
r.session_id = req.session_id
r.prompt = req.prompt
r.negative_prompt = req.negative_prompt
r.init_image = req.init_image
r.mask = req.mask
r.num_outputs = req.num_outputs
r.num_inference_steps = req.num_inference_steps
r.guidance_scale = req.guidance_scale
r.width = req.width
r.height = req.height
r.seed = req.seed
r.prompt_strength = req.prompt_strength
r.sampler = req.sampler
# r.allow_nsfw = req.allow_nsfw
r.turbo = req.turbo
r.use_cpu = req.use_cpu
r.use_full_precision = req.use_full_precision
r.save_to_disk_path = req.save_to_disk_path
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress
if not req.stream_progress_updates:
r.stream_image_progress = False
new_task = RenderTask(r)
if task_cache.put(r.session_id, new_task, TASK_TTL):
tasks_queue.put(new_task, block=True, timeout=30)
return new_task
raise RuntimeError('Failed to add task to cache.')

View File

@ -14,90 +14,32 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
from sd_internal import Request, Response from sd_internal import Request, Response, task_manager
app = FastAPI() app = FastAPI()
model_loaded = False
model_is_loading = False
modifiers_cache = None modifiers_cache = None
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
# don't show access log entries for URLs that start with the given prefix # don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails'] ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
stream_progress_updates: bool = False
stream_image_progress: bool = False
class SetAppConfigRequest(BaseModel): class SetAppConfigRequest(BaseModel):
update_branch: str = "main" update_branch: str = "main"
@app.get('/')
def read_root():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
@app.get('/ping')
async def ping():
global model_loaded, model_is_loading
try:
if model_loaded:
return {'OK'}
if model_is_loading:
return {'ERROR'}
model_is_loading = True
from sd_internal import runtime
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
model_loaded = True
model_is_loading = False
return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
# needs to support the legacy installations # needs to support the legacy installations
def get_initial_model_to_load(): def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
@ -114,7 +56,6 @@ def get_initial_model_to_load():
ckpt_to_use = model_path ckpt_to_use = model_path
else: else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt') print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use return ckpt_to_use
def resolve_model_to_use(model_name): def resolve_model_to_use(model_name):
@ -126,92 +67,110 @@ def resolve_model_to_use(model_name):
model_path = legacy_model_path model_path = legacy_model_path
else: else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path return model_path
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if not task_manager.render_thread.is_alive(): # Render thread is dead.
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
return HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
task = task_manager.task_cache.tryGet(session_id)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(model_name): def save_model_to_config(model_name):
config = getConfig() config = getConfig()
if 'model' not in config: if 'model' not in config:
config['model'] = {} config['model'] = {}
config['model']['stable-diffusion'] = model_name config['model']['stable-diffusion'] = model_name
setConfig(config) setConfig(config)
@app.post('/image') @app.post('/render')
def image(req : ImageRequest): def render(req : task_manager.ImageRequest):
from sd_internal import runtime
r = Request()
r.session_id = req.session_id
r.prompt = req.prompt
r.negative_prompt = req.negative_prompt
r.init_image = req.init_image
r.mask = req.mask
r.num_outputs = req.num_outputs
r.num_inference_steps = req.num_inference_steps
r.guidance_scale = req.guidance_scale
r.width = req.width
r.height = req.height
r.seed = req.seed
r.prompt_strength = req.prompt_strength
r.sampler = req.sampler
# r.allow_nsfw = req.allow_nsfw
r.turbo = req.turbo
r.use_cpu = req.use_cpu
r.use_full_precision = req.use_full_precision
r.save_to_disk_path = req.save_to_disk_path
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress
r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
save_model_to_config(req.use_stable_diffusion_model)
try: try:
if not req.stream_progress_updates: save_model_to_config(req.use_stable_diffusion_model)
r.stream_image_progress = False req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
new_task = task_manager.render(req)
res = runtime.mk_img(r) response = {
'status': str(task_manager.current_state),
if req.stream_progress_updates: 'queue': task_manager.tasks_queue.qsize(),
return StreamingResponse(res, media_type='application/json') 'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
else: # compatibility mode: buffer the streaming responses, and return the last one 'task': id(new_task)
last_result = None }
return JSONResponse(response, headers=NOCACHE_HEADERS)
for result in res: except ChildProcessError as e: # Render thread is dead
last_result = result return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
return json.loads(last_result) return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
except Exception as e: except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ??
task = task_manager.task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop') @app.get('/image/stop')
def stop(): def stop(session_id:str=None):
try: if not session_id:
if model_is_loading: if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
return {'ERROR'} return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
from sd_internal import runtime return {'OK'}
runtime.stop_processing = True task = task_manager.task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'} return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.get('/image/tmp/{session_id}/{img_id}') @app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id): def get_image(session_id, img_id):
from sd_internal import runtime task = task_manager.task_cache.tryGet(session_id)
buf = runtime.temp_images[session_id + '/' + img_id] if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
buf.seek(0) if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
return StreamingResponse(buf, media_type='image/jpeg') try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config') @app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest): async def setAppConfig(req : SetAppConfigRequest):
@ -242,42 +201,27 @@ async def setAppConfig(req : SetAppConfigRequest):
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@app.get('/app_config') def getConfig(default_val={}):
def getAppConfig():
try: try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path): if not os.path.exists(config_json_path):
return HTTPException(status_code=500, detail="No config file") return default_val
with open(config_json_path, 'r') as f: with open(config_json_path, 'r') as f:
return json.load(f) return json.load(f)
except Exception as e: except Exception as e:
print(str(e))
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return default_val
def getConfig():
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return {}
with open(config_json_path, 'r') as f:
return json.load(f)
except Exception as e:
return {}
def setConfig(config): def setConfig(config):
try: try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f: with open(config_json_path, 'w') as f:
return json.dump(config, f) return json.dump(config, f)
except: except:
print(str(e))
print(traceback.format_exc()) print(traceback.format_exc())
@app.get('/models')
def getModels(): def getModels():
models = { models = {
'active': { 'active': {
@ -307,14 +251,21 @@ def getModels():
return models return models
@app.get('/modifiers.json') @app.get('/get/{key:path}')
def read_modifiers(): def read_web_data(key:str=None):
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} if not key: # /get without parameters, stable-diffusion easter egg.
return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=headers) return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
@app.get('/output_dir') config = getConfig(default_val=None)
def read_home_dir(): if config is None:
return {outpath} return HTTPException(status_code=500, detail="Config file is missing or unreadable")
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
else:
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
# don't log certain requests # don't log certain requests
class LogSuppressFilter(logging.Filter): class LogSuppressFilter(logging.Filter):
@ -323,10 +274,11 @@ class LogSuppressFilter(logging.Filter):
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1: if path.find(prefix) != -1:
return False return False
return True return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
task_manager.default_model_to_load = get_initial_model_to_load()
task_manager.start_render_thread()
# start the browser ui # start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000') import webbrowser; webbrowser.open('http://localhost:9000')