Render queue first draft

This commit is contained in:
Marc-Andre Ferland 2022-10-14 03:47:25 -04:00
parent 4b88cfa51a
commit 1ec9d986bb
2 changed files with 393 additions and 116 deletions

View File

@ -117,7 +117,7 @@ maskResetButton.innerHTML = 'Clear'
maskResetButton.style.fontWeight = 'normal' maskResetButton.style.fontWeight = 'normal'
maskResetButton.style.fontSize = '10pt' maskResetButton.style.fontSize = '10pt'
let serverStatus = 'offline' let serverState = {'status': 'Offline'}
let activeTags = [] let activeTags = []
let modifiers = [] let modifiers = []
let lastPromptUsed = '' let lastPromptUsed = ''
@ -212,21 +212,38 @@ function getOutputFormat() {
} }
function setStatus(statusType, msg, msgType) { function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') {
return
} }
if (msgType == 'error') { function setServerStatus(msgType, msg) {
// msg = '<span style="color: red">' + msg + '<span>' switch(msgType) {
case 'online':
serverStatusColor.style.color = 'green'
serverStatusMsg.style.color = 'green'
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
break
case 'busy':
serverStatusColor.style.color = 'yellow'
serverStatusMsg.style.color = 'yellow'
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
break
case 'error':
serverStatusColor.style.color = 'red' serverStatusColor.style.color = 'red'
serverStatusMsg.style.color = 'red' serverStatusMsg.style.color = 'red'
serverStatusMsg.innerText = 'Stable Diffusion has stopped' serverStatusMsg.innerText = 'Stable Diffusion has stopped'
} else if (msgType == 'success') { break
// msg = '<span style="color: green">' + msg + '<span>' }
serverStatusColor.style.color = 'green' }
serverStatusMsg.style.color = 'green' function isServerAvailable() {
serverStatusMsg.innerText = 'Stable Diffusion is ready' if (typeof serverState !== 'object') {
serverStatus = 'online' return false
}
switch (serverState.status) {
case 'LoadingModel':
case 'Rendering':
case 'Online':
return true
default:
return false
} }
} }
@ -250,6 +267,11 @@ function logError(msg, res, outputMsg) {
console.log('request error', res) console.log('request error', res)
setStatus('request', 'error', 'error') setStatus('request', 'error', 'error')
} }
function asyncDelay(timeout) {
return new Promise(function(resolve, reject) {
setTimeout(resolve, timeout, true)
})
}
function playSound() { function playSound() {
const audio = new Audio('/media/ding.mp3') const audio = new Audio('/media/ding.mp3')
@ -259,16 +281,34 @@ function playSound() {
async function healthCheck() { async function healthCheck() {
try { try {
let res = await fetch('/ping') let res = undefined
res = await res.json() if (sessionId) {
res = await fetch('/ping?session_id=' + sessionId)
if (res[0] == 'OK') {
setStatus('server', 'online', 'success')
} else { } else {
setStatus('server', 'offline', 'error') res = await fetch('/ping')
} }
serverState = await res.json()
// Set status
switch(serverState.status) {
case 'Init':
// Wait for init to complete before updating status.
break
case 'Online':
setServerStatus('online', 'ready')
break
case 'LoadingModel':
setServerStatus('busy', 'loading model')
break
case 'Rendering':
setServerStatus('busy', 'rendering')
break
default: // Unavailable
setServerStatus('error', serverState.status.toLowerCase())
break
}
serverState.time = Date.now()
} catch (e) { } catch (e) {
setStatus('server', 'offline', 'error') setServerStatus('error', 'offline')
} }
} }
function resizeInpaintingEditor() { function resizeInpaintingEditor() {
@ -311,7 +351,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
if(typeof res != 'object') return if(typeof res != 'object') return
res.output.reverse() res.output.reverse()
res.output.forEach((result, index) => { res.output.forEach((result, index) => {
const imageData = result?.data || result?.path + '?t=' + new Date().getTime() const imageData = result?.data || result?.path + '?t=' + Date.now()
const imageWidth = reqBody.width const imageWidth = reqBody.width
const imageHeight = reqBody.height const imageHeight = reqBody.height
if (!imageData.includes('/')) { if (!imageData.includes('/')) {
@ -409,8 +449,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) {
} }
function getStartNewTaskHandler(reqBody, imageItemElem, mode) { function getStartNewTaskHandler(reqBody, imageItemElem, mode) {
return function() { return function() {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
alert('The server is still starting up..') alert('The server is not available.')
return return
} }
const imageElem = imageItemElem.querySelector('img') const imageElem = imageItemElem.querySelector('img')
@ -473,37 +513,68 @@ async function doMakeImage(task) {
const progressBar = task['progressBar'] const progressBar = task['progressBar']
let res = undefined let res = undefined
let stepUpdate = undefined
try { try {
res = await fetch('/image', { const lastTask = serverState.task
let renderRequest = undefined
do {
res = await fetch('/render', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify(reqBody) body: JSON.stringify(reqBody)
}) })
renderRequest = await res.json()
// status_code 503, already a task running.
} while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000))
if (typeof renderRequest?.stream !== 'string') {
console.log('Endpoint response: ', renderRequest)
throw new Error('Endpoint response does not contains a response stream url.')
}
task['taskStatusLabel'].innerText = "Busy/Waiting"
do { // Wait for server status to update.
await asyncDelay(250)
if (!isServerAvailable()) {
throw new Error('Connexion with server lost.')
}
} while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task)
if (serverState.session !== 'pending' && serverState.session !== 'running') {
throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
}
do { // Wait for task to start on server.
await asyncDelay(1500)
} while (serverState?.session === 'pending')
// Task started!
res = await fetch(renderRequest.stream, {
headers: {
'Content-Type': 'application/json'
},
})
task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].classList.add('activeTaskLabel')
task['taskStatusLabel'].classList.remove('waitingTaskLabel')
let stepUpdate = undefined
let reader = res.body.getReader() let reader = res.body.getReader()
let textDecoder = new TextDecoder() let textDecoder = new TextDecoder()
let finalJSON = '' let finalJSON = ''
let prevTime = -1 let prevTime = -1
let readComplete = false let readComplete = false
while (true) { while (!readComplete || finalJSON.length > 0) {
let t = new Date().getTime() let t = Date.now()
let jsonStr = '' let jsonStr = ''
if (!readComplete) { if (!readComplete) {
const {value, done} = await reader.read() const {value, done} = await reader.read()
if (done) { if (done) {
readComplete = true readComplete = true
} }
if (done && finalJSON.length <= 0 && !value) {
break
}
if (value) { if (value) {
jsonStr = textDecoder.decode(value) jsonStr = textDecoder.decode(value)
} }
} }
stepUpdate = undefined
try { try {
// hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot. // hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot.
// this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste... // this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste...
@ -537,9 +608,6 @@ async function doMakeImage(task) {
throw e throw e
} }
} }
if (readComplete && finalJSON.length <= 0) {
break
}
if (typeof stepUpdate === 'object' && 'step' in stepUpdate) { if (typeof stepUpdate === 'object' && 'step' in stepUpdate) {
let batchSize = stepUpdate.total_steps let batchSize = stepUpdate.total_steps
let overallStepCount = stepUpdate.step + task.batchesDone * batchSize let overallStepCount = stepUpdate.step + task.batchesDone * batchSize
@ -564,6 +632,23 @@ async function doMakeImage(task) {
showImages(reqBody, stepUpdate, outputContainer, true) showImages(reqBody, stepUpdate, outputContainer, true)
} }
} }
if (stepUpdate?.status) {
break
}
if (readComplete && finalJSON.length <= 0) {
if (res.status === 200) {
await asyncDelay(5000)
res = await fetch(renderRequest.stream, {
headers: {
'Content-Type': 'application/json'
},
})
reader = res.body.getReader()
readComplete = false
} else {
console.log('Stream stopped: ', res)
}
}
prevTime = t prevTime = t
} }
@ -586,21 +671,22 @@ async function doMakeImage(task) {
return false return false
} }
if (typeof stepUpdate !== 'object' || !res || res.status != 200) { if (typeof stepUpdate !== 'object' || !res || res.status != 200) {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg) logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg)
} else if (typeof res === 'object') { } else if (typeof res === 'object') {
let msg = 'Stable Diffusion had an error reading the response: ' let msg = 'Stable Diffusion had an error reading the response: '
try { // 'Response': body stream already read try { // 'Response': body stream already read
msg += 'Read: ' + await res.text() msg += 'Read: ' + await res.text()
} catch(e) { } catch(e) {
msg += 'No error response. ' msg += 'Unexpected end of stream. '
} }
if (finalJSON) { if (finalJSON) {
msg += 'Buffered data: ' + finalJSON msg += 'Buffered data: ' + finalJSON
} }
logError(msg, res, outputMsg) logError(msg, res, outputMsg)
} else { } else {
msg = `Unexpected Read Error:<br/><pre>Response:${res}<br/>StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>` let msg = `Unexpected Read Error:<br/><pre>Response: ${res}<br/>StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
logError(msg, res, outputMsg)
} }
progressBar.style.display = 'none' progressBar.style.display = 'none'
return false return false
@ -654,8 +740,8 @@ async function checkTasks() {
task.isProcessing = true task.isProcessing = true
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop' task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
task['taskStatusLabel'].innerText = "Processing" task['taskStatusLabel'].innerText = "Starting"
task['taskStatusLabel'].className += " activeTaskLabel" task['taskStatusLabel'].classList.add('waitingTaskLabel')
const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1)) const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1))
const startSeed = task.reqBody.seed || task.seed const startSeed = task.reqBody.seed || task.seed
@ -780,8 +866,8 @@ function getCurrentUserRequest() {
} }
function makeImage() { function makeImage() {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
alert('The server is still starting up..') alert('The server is not available.')
return return
} }
const taskTemplate = getCurrentUserRequest() const taskTemplate = getCurrentUserRequest()
@ -834,7 +920,7 @@ function createTask(task) {
if (task['isProcessing']) { if (task['isProcessing']) {
task.isProcessing = false task.isProcessing = false
try { try {
let res = await fetch('/image/stop') let res = await fetch('/image/stop?session_id=' + sessionId)
} catch (e) { } catch (e) {
console.log(e) console.log(e)
} }
@ -1094,9 +1180,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength() updatePromptStrength()
useBetaChannelField.addEventListener('click', async function(e) { useBetaChannelField.addEventListener('click', async function(e) {
if (serverStatus !== 'online') { if (!isServerAvailable()) {
// logError('The server is still starting up..') // logError('The server is still starting up..')
alert('The server is still starting up..') alert('The server is not available.')
e.preventDefault() e.preventDefault()
return false return false
} }

View File

@ -14,28 +14,55 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 * 1000 # Discard last session's task timeout
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
from sd_internal import Request, Response from sd_internal import Request, Response
app = FastAPI() app = FastAPI()
model_loaded = False
model_is_loading = False
modifiers_cache = None modifiers_cache = None
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
# don't show access log entries for URLs that start with the given prefix # don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails'] ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__
class Symbol(metaclass=SymbolClass): pass
class ServerStates:
class Init(Symbol): pass
class LoadingModel(Symbol): pass
class Online(Symbol): pass
class Rendering(Symbol): pass
class Unavailable(Symbol): pass
class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
self.request: Request = req # Initial Request
self.response: Any = None # Copy of the last reponse
self.temp_images:[] = [None] * req.num_outputs
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
tasks_queue = queue.Queue()
# defaults from https://huggingface.co/blog/stable_diffusion # defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel): class ImageRequest(BaseModel):
session_id: str = "session" session_id: str = "session"
@ -65,38 +92,152 @@ class ImageRequest(BaseModel):
stream_progress_updates: bool = False stream_progress_updates: bool = False
stream_image_progress: bool = False stream_image_progress: bool = False
# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
def __init__(self):
self._base = dict()
def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp
def clean(self) -> None:
for key in self._base:
ttl, _ = self._base[key]
if self._is_expired(ttl):
del self._base[key]
def clear(self) -> None:
self._base.clear()
def delete(self, key: Hashable) -> bool:
if key not in self._base:
return False
del self._base[key]
return True
def keep(self, key: Hashable, ttl: int) -> bool:
if key in self._base:
_, value = self._base.get(key)
self._base[key] = (self._get_ttl_time(ttl), value)
return True
return False
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
try:
self._base[key] = (
self._get_ttl_time(ttl), value
)
except Exception:
return False
return True
def tryGet(self, key: Hashable) -> Any:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
self.delete(key)
return None
return value
task_cache = TaskCache()
class SetAppConfigRequest(BaseModel): class SetAppConfigRequest(BaseModel):
update_branch: str = "main" update_branch: str = "main"
@app.get('/') @app.get('/')
def read_root(): def read_root():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
@app.get('/ping')
async def ping():
global model_loaded, model_is_loading
def preload_model(file_path=None):
global current_state, current_state_error, current_model_path
if file_path == None:
file_path = get_initial_model_to_load()
if file_path == current_model_path:
return
current_state = ServerStates.LoadingModel
try: try:
if model_loaded:
return {'OK'}
if model_is_loading:
return {'ERROR'}
model_is_loading = True
from sd_internal import runtime from sd_internal import runtime
runtime.load_model_ckpt(ckpt_to_use=file_path)
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load()) current_model_path = file_path
current_state_error = None
model_loaded = True current_state = ServerStates.Online
model_is_loading = False
return {'OK'}
except Exception as e: except Exception as e:
current_model_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc()) print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def thread_render():
global current_state, current_state_error
from sd_internal import runtime
current_state = ServerStates.Online
preload_model()
while True:
task_cache.clean()
task = None
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
try:
task = tasks_queue.get(timeout=1)
except queue.Empty as e:
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
else: continue
#if current_model_path != task.request.use_stable_diffusion_model:
# preload_model(task.request.use_stable_diffusion_model)
if current_state_error:
task.error = current_state_error
continue
print(f'Session {task.request.session_id} starting task {id(task)}')
current_state = ServerStates.Rendering
try:
task.lock.acquire(blocking=False)
res = runtime.mk_img(task.request)
except Exception as e:
task.error = e
task.lock.release()
tasks_queue.task_done()
print(traceback.format_exc())
continue
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
task_cache.keep(task.request.session_id, TASK_TTL)
# Task completed
task.lock.release()
tasks_queue.task_done()
task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
elif task.error is not None:
print(f'Session {task.request.session_id} task {id(task)} failed!')
else:
print(f'Session {task.request.session_id} task {id(task)} completed.')
current_state = ServerStates.Online
# Start Rendering Thread
render_thread = threading.Thread(target=thread_render)
render_thread.daemon = True
render_thread.start()
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
# needs to support the legacy installations # needs to support the legacy installations
def get_initial_model_to_load(): def get_initial_model_to_load():
@ -126,7 +267,6 @@ def resolve_model_to_use(model_name):
model_path = legacy_model_path model_path = legacy_model_path
else: else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path return model_path
def save_model_to_config(model_name): def save_model_to_config(model_name):
@ -135,13 +275,42 @@ def save_model_to_config(model_name):
config['model'] = {} config['model'] = {}
config['model']['stable-diffusion'] = model_name config['model']['stable-diffusion'] = model_name
setConfig(config) setConfig(config)
@app.post('/image') @app.get('/ping') # Get server and optionally session status.
def image(req : ImageRequest): def ping(session_id:str=None):
from sd_internal import runtime if current_state_error or not render_thread.is_alive(): # Render thread is dead.
return HTTPException(status_code=500, detail=str(current_state_error))
# Alive
response = {'status': str(current_state)}
if session_id:
task = task_cache.tryGet(session_id)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
@app.post('/render')
def render(req : ImageRequest):
if not render_thread.is_alive(): # Render thread is dead
return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
# Alive, check if task in cache
task = task_cache.tryGet(req.session_id)
if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one.
return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
#
from sd_internal import runtime
r = Request() r = Request()
r.session_id = req.session_id r.session_id = req.session_id
r.prompt = req.prompt r.prompt = req.prompt
@ -173,45 +342,70 @@ def image(req : ImageRequest):
save_model_to_config(req.use_stable_diffusion_model) save_model_to_config(req.use_stable_diffusion_model)
try:
if not req.stream_progress_updates: if not req.stream_progress_updates:
r.stream_image_progress = False r.stream_image_progress = False
res = runtime.mk_img(r) new_task = RenderTask(r)
task_cache.put(r.session_id, new_task, TASK_TTL)
tasks_queue.put(new_task)
if req.stream_progress_updates: response = {
return StreamingResponse(res, media_type='application/json') 'status': str(current_state),
else: # compatibility mode: buffer the streaming responses, and return the last one 'queue': tasks_queue.qsize(),
last_result = None 'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
for result in res: async def read_data_generator(data:queue.Queue, lock:threading.Lock):
last_result = result try:
while not data.empty():
res = data.get(block=False)
data.task_done()
yield res
except queue.Empty as e: yield
return json.loads(last_result) @app.get('/image/stream/{session_id:str}/{task_id:int}')
except Exception as e: def stream(session_id:str, task_id:int):
print(traceback.format_exc()) #TODO Move to WebSockets ??
return HTTPException(status_code=500, detail=str(e)) task = task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json')
@app.get('/image/stop') @app.get('/image/stop')
def stop(): def stop(session_id:str=None):
try: if not session_id:
if model_is_loading: if current_state == ServerStates.Online or current_state == ServerStates.Unavailable:
return {'ERROR'} return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
global current_state_error
from sd_internal import runtime current_state_error = StopAsyncIteration()
runtime.stop_processing = True return {'OK'}
task = task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'} return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.get('/image/tmp/{session_id}/{img_id}') @app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id): def get_image(session_id, img_id):
from sd_internal import runtime task = task_cache.tryGet(session_id)
buf = runtime.temp_images[session_id + '/' + img_id] if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
buf.seek(0) if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
return StreamingResponse(buf, media_type='image/jpeg') try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config') @app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest): async def setAppConfig(req : SetAppConfigRequest):
@ -257,7 +451,6 @@ def getConfig(default_val={}):
def setConfig(config): def setConfig(config):
try: try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f: with open(config_json_path, 'w') as f:
return json.dump(config, f) return json.dump(config, f)
except: except:
@ -316,9 +509,7 @@ class LogSuppressFilter(logging.Filter):
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1: if path.find(prefix) != -1:
return False return False
return True return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# start the browser ui # start the browser ui