Render queue first draft

This commit is contained in:
Marc-Andre Ferland 2022-10-14 03:47:25 -04:00
parent 4b88cfa51a
commit 1ec9d986bb
2 changed files with 393 additions and 116 deletions

View File

@ -117,7 +117,7 @@ maskResetButton.innerHTML = 'Clear'
maskResetButton.style.fontWeight = 'normal'
maskResetButton.style.fontSize = '10pt'
let serverStatus = 'offline'
let serverState = {'status': 'Offline'}
let activeTags = []
let modifiers = []
let lastPromptUsed = ''
@ -212,21 +212,38 @@ function getOutputFormat() {
}
function setStatus(statusType, msg, msgType) {
if (statusType !== 'server') {
return
}
}
if (msgType == 'error') {
// msg = '<span style="color: red">' + msg + '<span>'
serverStatusColor.style.color = 'red'
serverStatusMsg.style.color = 'red'
serverStatusMsg.innerText = 'Stable Diffusion has stopped'
} else if (msgType == 'success') {
// msg = '<span style="color: green">' + msg + '<span>'
serverStatusColor.style.color = 'green'
serverStatusMsg.style.color = 'green'
serverStatusMsg.innerText = 'Stable Diffusion is ready'
serverStatus = 'online'
function setServerStatus(msgType, msg) {
switch(msgType) {
case 'online':
serverStatusColor.style.color = 'green'
serverStatusMsg.style.color = 'green'
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
break
case 'busy':
serverStatusColor.style.color = 'yellow'
serverStatusMsg.style.color = 'yellow'
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
break
case 'error':
serverStatusColor.style.color = 'red'
serverStatusMsg.style.color = 'red'
serverStatusMsg.innerText = 'Stable Diffusion has stopped'
break
}
}
function isServerAvailable() {
if (typeof serverState !== 'object') {
return false
}
switch (serverState.status) {
case 'LoadingModel':
case 'Rendering':
case 'Online':
return true
default:
return false
}
}
@ -250,6 +267,11 @@ function logError(msg, res, outputMsg) {
console.log('request error', res)
setStatus('request', 'error', 'error')
}
function asyncDelay(timeout) {
return new Promise(function(resolve, reject) {
setTimeout(resolve, timeout, true)
})
}
function playSound() {
const audio = new Audio('/media/ding.mp3')
@ -259,16 +281,34 @@ function playSound() {
async function healthCheck() {
try {
let res = await fetch('/ping')
res = await res.json()
if (res[0] == 'OK') {
setStatus('server', 'online', 'success')
let res = undefined
if (sessionId) {
res = await fetch('/ping?session_id=' + sessionId)
} else {
setStatus('server', 'offline', 'error')
res = await fetch('/ping')
}
serverState = await res.json()
// Set status
switch(serverState.status) {
case 'Init':
// Wait for init to complete before updating status.
break
case 'Online':
setServerStatus('online', 'ready')
break
case 'LoadingModel':
setServerStatus('busy', 'loading model')
break
case 'Rendering':
setServerStatus('busy', 'rendering')
break
default: // Unavailable
setServerStatus('error', serverState.status.toLowerCase())
break
}
serverState.time = Date.now()
} catch (e) {
setStatus('server', 'offline', 'error')
setServerStatus('error', 'offline')
}
}
function resizeInpaintingEditor() {
@ -311,7 +351,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
if(typeof res != 'object') return
res.output.reverse()
res.output.forEach((result, index) => {
const imageData = result?.data || result?.path + '?t=' + new Date().getTime()
const imageData = result?.data || result?.path + '?t=' + Date.now()
const imageWidth = reqBody.width
const imageHeight = reqBody.height
if (!imageData.includes('/')) {
@ -409,8 +449,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) {
}
function getStartNewTaskHandler(reqBody, imageItemElem, mode) {
return function() {
if (serverStatus !== 'online') {
alert('The server is still starting up..')
if (!isServerAvailable()) {
alert('The server is not available.')
return
}
const imageElem = imageItemElem.querySelector('img')
@ -473,37 +513,68 @@ async function doMakeImage(task) {
const progressBar = task['progressBar']
let res = undefined
let stepUpdate = undefined
try {
res = await fetch('/image', {
method: 'POST',
const lastTask = serverState.task
let renderRequest = undefined
do {
res = await fetch('/render', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(reqBody)
})
renderRequest = await res.json()
// status_code 503, already a task running.
} while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000))
if (typeof renderRequest?.stream !== 'string') {
console.log('Endpoint response: ', renderRequest)
throw new Error('Endpoint response does not contains a response stream url.')
}
task['taskStatusLabel'].innerText = "Busy/Waiting"
do { // Wait for server status to update.
await asyncDelay(250)
if (!isServerAvailable()) {
throw new Error('Connexion with server lost.')
}
} while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task)
if (serverState.session !== 'pending' && serverState.session !== 'running') {
throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
}
do { // Wait for task to start on server.
await asyncDelay(1500)
} while (serverState?.session === 'pending')
// Task started!
res = await fetch(renderRequest.stream, {
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify(reqBody)
})
task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].classList.add('activeTaskLabel')
task['taskStatusLabel'].classList.remove('waitingTaskLabel')
let stepUpdate = undefined
let reader = res.body.getReader()
let textDecoder = new TextDecoder()
let finalJSON = ''
let prevTime = -1
let readComplete = false
while (true) {
let t = new Date().getTime()
while (!readComplete || finalJSON.length > 0) {
let t = Date.now()
let jsonStr = ''
if (!readComplete) {
const {value, done} = await reader.read()
if (done) {
readComplete = true
}
if (done && finalJSON.length <= 0 && !value) {
break
}
if (value) {
jsonStr = textDecoder.decode(value)
}
}
stepUpdate = undefined
try {
// hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot.
// this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste...
@ -537,9 +608,6 @@ async function doMakeImage(task) {
throw e
}
}
if (readComplete && finalJSON.length <= 0) {
break
}
if (typeof stepUpdate === 'object' && 'step' in stepUpdate) {
let batchSize = stepUpdate.total_steps
let overallStepCount = stepUpdate.step + task.batchesDone * batchSize
@ -564,6 +632,23 @@ async function doMakeImage(task) {
showImages(reqBody, stepUpdate, outputContainer, true)
}
}
if (stepUpdate?.status) {
break
}
if (readComplete && finalJSON.length <= 0) {
if (res.status === 200) {
await asyncDelay(5000)
res = await fetch(renderRequest.stream, {
headers: {
'Content-Type': 'application/json'
},
})
reader = res.body.getReader()
readComplete = false
} else {
console.log('Stream stopped: ', res)
}
}
prevTime = t
}
@ -580,27 +665,28 @@ async function doMakeImage(task) {
3. Try generating a smaller image.<br/>`
}
} else {
msg = `Unexpected Read Error:<br/><pre>StepUpdate:${JSON.stringify(stepUpdate, undefined, 4)}</pre>`
msg = `Unexpected Read Error:<br/><pre>StepUpdate: ${JSON.stringify(stepUpdate, undefined, 4)}</pre>`
}
logError(msg, res, outputMsg)
return false
}
if (typeof stepUpdate !== 'object' || !res || res.status != 200) {
if (serverStatus !== 'online') {
if (!isServerAvailable()) {
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg)
} else if (typeof res === 'object') {
let msg = 'Stable Diffusion had an error reading the response: '
try { // 'Response': body stream already read
msg += 'Read: ' + await res.text()
} catch(e) {
msg += 'No error response. '
msg += 'Unexpected end of stream. '
}
if (finalJSON) {
msg += 'Buffered data: ' + finalJSON
}
logError(msg, res, outputMsg)
} else {
msg = `Unexpected Read Error:<br/><pre>Response:${res}<br/>StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
let msg = `Unexpected Read Error:<br/><pre>Response: ${res}<br/>StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
logError(msg, res, outputMsg)
}
progressBar.style.display = 'none'
return false
@ -654,8 +740,8 @@ async function checkTasks() {
task.isProcessing = true
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
task['taskStatusLabel'].innerText = "Processing"
task['taskStatusLabel'].className += " activeTaskLabel"
task['taskStatusLabel'].innerText = "Starting"
task['taskStatusLabel'].classList.add('waitingTaskLabel')
const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1))
const startSeed = task.reqBody.seed || task.seed
@ -780,8 +866,8 @@ function getCurrentUserRequest() {
}
function makeImage() {
if (serverStatus !== 'online') {
alert('The server is still starting up..')
if (!isServerAvailable()) {
alert('The server is not available.')
return
}
const taskTemplate = getCurrentUserRequest()
@ -834,7 +920,7 @@ function createTask(task) {
if (task['isProcessing']) {
task.isProcessing = false
try {
let res = await fetch('/image/stop')
let res = await fetch('/image/stop?session_id=' + sessionId)
} catch (e) {
console.log(e)
}
@ -1094,9 +1180,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
updatePromptStrength()
useBetaChannelField.addEventListener('click', async function(e) {
if (serverStatus !== 'online') {
if (!isServerAvailable()) {
// logError('The server is still starting up..')
alert('The server is still starting up..')
alert('The server is not available.')
e.preventDefault()
return false
}

View File

@ -14,28 +14,55 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 * 1000 # Discard last session's task timeout
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, StreamingResponse
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
from sd_internal import Request, Response
app = FastAPI()
model_loaded = False
model_is_loading = False
modifiers_cache = None
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
# don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails']
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__
class Symbol(metaclass=SymbolClass): pass
class ServerStates:
class Init(Symbol): pass
class LoadingModel(Symbol): pass
class Online(Symbol): pass
class Rendering(Symbol): pass
class Unavailable(Symbol): pass
class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
self.request: Request = req # Initial Request
self.response: Any = None # Copy of the last reponse
self.temp_images:[] = [None] * req.num_outputs
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
tasks_queue = queue.Queue()
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
session_id: str = "session"
@ -65,38 +92,152 @@ class ImageRequest(BaseModel):
stream_progress_updates: bool = False
stream_image_progress: bool = False
# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
def __init__(self):
self._base = dict()
def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp
def clean(self) -> None:
for key in self._base:
ttl, _ = self._base[key]
if self._is_expired(ttl):
del self._base[key]
def clear(self) -> None:
self._base.clear()
def delete(self, key: Hashable) -> bool:
if key not in self._base:
return False
del self._base[key]
return True
def keep(self, key: Hashable, ttl: int) -> bool:
if key in self._base:
_, value = self._base.get(key)
self._base[key] = (self._get_ttl_time(ttl), value)
return True
return False
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
try:
self._base[key] = (
self._get_ttl_time(ttl), value
)
except Exception:
return False
return True
def tryGet(self, key: Hashable) -> Any:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
self.delete(key)
return None
return value
task_cache = TaskCache()
class SetAppConfigRequest(BaseModel):
update_branch: str = "main"
@app.get('/')
def read_root():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
@app.get('/ping')
async def ping():
global model_loaded, model_is_loading
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
def preload_model(file_path=None):
global current_state, current_state_error, current_model_path
if file_path == None:
file_path = get_initial_model_to_load()
if file_path == current_model_path:
return
current_state = ServerStates.LoadingModel
try:
if model_loaded:
return {'OK'}
if model_is_loading:
return {'ERROR'}
model_is_loading = True
from sd_internal import runtime
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
model_loaded = True
model_is_loading = False
return {'OK'}
runtime.load_model_ckpt(ckpt_to_use=file_path)
current_model_path = file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def thread_render():
global current_state, current_state_error
from sd_internal import runtime
current_state = ServerStates.Online
preload_model()
while True:
task_cache.clean()
task = None
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
try:
task = tasks_queue.get(timeout=1)
except queue.Empty as e:
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
else: continue
#if current_model_path != task.request.use_stable_diffusion_model:
# preload_model(task.request.use_stable_diffusion_model)
if current_state_error:
task.error = current_state_error
continue
print(f'Session {task.request.session_id} starting task {id(task)}')
current_state = ServerStates.Rendering
try:
task.lock.acquire(blocking=False)
res = runtime.mk_img(task.request)
except Exception as e:
task.error = e
task.lock.release()
tasks_queue.task_done()
print(traceback.format_exc())
continue
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
task_cache.keep(task.request.session_id, TASK_TTL)
# Task completed
task.lock.release()
tasks_queue.task_done()
task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
elif task.error is not None:
print(f'Session {task.request.session_id} task {id(task)} failed!')
else:
print(f'Session {task.request.session_id} task {id(task)} completed.')
current_state = ServerStates.Online
# Start Rendering Thread
render_thread = threading.Thread(target=thread_render)
render_thread.daemon = True
render_thread.start()
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
# needs to support the legacy installations
def get_initial_model_to_load():
@ -126,7 +267,6 @@ def resolve_model_to_use(model_name):
model_path = legacy_model_path
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
def save_model_to_config(model_name):
@ -135,13 +275,42 @@ def save_model_to_config(model_name):
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/image')
def image(req : ImageRequest):
from sd_internal import runtime
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if current_state_error or not render_thread.is_alive(): # Render thread is dead.
return HTTPException(status_code=500, detail=str(current_state_error))
# Alive
response = {'status': str(current_state)}
if session_id:
task = task_cache.tryGet(session_id)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
@app.post('/render')
def render(req : ImageRequest):
if not render_thread.is_alive(): # Render thread is dead
return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
# Alive, check if task in cache
task = task_cache.tryGet(req.session_id)
if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one.
return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
#
from sd_internal import runtime
r = Request()
r.session_id = req.session_id
r.prompt = req.prompt
@ -173,45 +342,70 @@ def image(req : ImageRequest):
save_model_to_config(req.use_stable_diffusion_model)
if not req.stream_progress_updates:
r.stream_image_progress = False
new_task = RenderTask(r)
task_cache.put(r.session_id, new_task, TASK_TTL)
tasks_queue.put(new_task)
response = {
'status': str(current_state),
'queue': tasks_queue.qsize(),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
async def read_data_generator(data:queue.Queue, lock:threading.Lock):
try:
if not req.stream_progress_updates:
r.stream_image_progress = False
while not data.empty():
res = data.get(block=False)
data.task_done()
yield res
except queue.Empty as e: yield
res = runtime.mk_img(r)
if req.stream_progress_updates:
return StreamingResponse(res, media_type='application/json')
else: # compatibility mode: buffer the streaming responses, and return the last one
last_result = None
for result in res:
last_result = result
return json.loads(last_result)
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ??
task = task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json')
@app.get('/image/stop')
def stop():
try:
if model_is_loading:
return {'ERROR'}
from sd_internal import runtime
runtime.stop_processing = True
def stop(session_id:str=None):
if not session_id:
if current_state == ServerStates.Online or current_state == ServerStates.Unavailable:
return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
global current_state_error
current_state_error = StopAsyncIteration()
return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
task = task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id}')
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
from sd_internal import runtime
buf = runtime.temp_images[session_id + '/' + img_id]
buf.seek(0)
return StreamingResponse(buf, media_type='image/jpeg')
task = task_cache.tryGet(session_id)
if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
@ -257,7 +451,6 @@ def getConfig(default_val={}):
def setConfig(config):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
@ -316,9 +509,7 @@ class LogSuppressFilter(logging.Filter):
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# start the browser ui