diff --git a/ui/media/main.js b/ui/media/main.js index 0d7757d7..2d8fd6cd 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -117,7 +117,7 @@ maskResetButton.innerHTML = 'Clear' maskResetButton.style.fontWeight = 'normal' maskResetButton.style.fontSize = '10pt' -let serverStatus = 'offline' +let serverState = {'status': 'Offline'} let activeTags = [] let modifiers = [] let lastPromptUsed = '' @@ -212,21 +212,38 @@ function getOutputFormat() { } function setStatus(statusType, msg, msgType) { - if (statusType !== 'server') { - return - } +} - if (msgType == 'error') { - // msg = '' + msg + '' - serverStatusColor.style.color = 'red' - serverStatusMsg.style.color = 'red' - serverStatusMsg.innerText = 'Stable Diffusion has stopped' - } else if (msgType == 'success') { - // msg = '' + msg + '' - serverStatusColor.style.color = 'green' - serverStatusMsg.style.color = 'green' - serverStatusMsg.innerText = 'Stable Diffusion is ready' - serverStatus = 'online' +function setServerStatus(msgType, msg) { + switch(msgType) { + case 'online': + serverStatusColor.style.color = 'green' + serverStatusMsg.style.color = 'green' + serverStatusMsg.innerText = 'Stable Diffusion is ' + msg + break + case 'busy': + serverStatusColor.style.color = 'yellow' + serverStatusMsg.style.color = 'yellow' + serverStatusMsg.innerText = 'Stable Diffusion is ' + msg + break + case 'error': + serverStatusColor.style.color = 'red' + serverStatusMsg.style.color = 'red' + serverStatusMsg.innerText = 'Stable Diffusion has stopped' + break + } +} +function isServerAvailable() { + if (typeof serverState !== 'object') { + return false + } + switch (serverState.status) { + case 'LoadingModel': + case 'Rendering': + case 'Online': + return true + default: + return false } } @@ -250,6 +267,11 @@ function logError(msg, res, outputMsg) { console.log('request error', res) setStatus('request', 'error', 'error') } +function asyncDelay(timeout) { + return new Promise(function(resolve, reject) { + setTimeout(resolve, timeout, true) + }) +} function playSound() { const audio = new Audio('/media/ding.mp3') @@ -259,16 +281,34 @@ function playSound() { async function healthCheck() { try { - let res = await fetch('/ping') - res = await res.json() - - if (res[0] == 'OK') { - setStatus('server', 'online', 'success') + let res = undefined + if (sessionId) { + res = await fetch('/ping?session_id=' + sessionId) } else { - setStatus('server', 'offline', 'error') + res = await fetch('/ping') } + serverState = await res.json() + // Set status + switch(serverState.status) { + case 'Init': + // Wait for init to complete before updating status. + break + case 'Online': + setServerStatus('online', 'ready') + break + case 'LoadingModel': + setServerStatus('busy', 'loading model') + break + case 'Rendering': + setServerStatus('busy', 'rendering') + break + default: // Unavailable + setServerStatus('error', serverState.status.toLowerCase()) + break + } + serverState.time = Date.now() } catch (e) { - setStatus('server', 'offline', 'error') + setServerStatus('error', 'offline') } } function resizeInpaintingEditor() { @@ -311,7 +351,7 @@ function showImages(reqBody, res, outputContainer, livePreview) { if(typeof res != 'object') return res.output.reverse() res.output.forEach((result, index) => { - const imageData = result?.data || result?.path + '?t=' + new Date().getTime() + const imageData = result?.data || result?.path + '?t=' + Date.now() const imageWidth = reqBody.width const imageHeight = reqBody.height if (!imageData.includes('/')) { @@ -409,8 +449,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) { } function getStartNewTaskHandler(reqBody, imageItemElem, mode) { return function() { - if (serverStatus !== 'online') { - alert('The server is still starting up..') + if (!isServerAvailable()) { + alert('The server is not available.') return } const imageElem = imageItemElem.querySelector('img') @@ -473,37 +513,68 @@ async function doMakeImage(task) { const progressBar = task['progressBar'] let res = undefined - let stepUpdate = undefined try { - res = await fetch('/image', { - method: 'POST', + const lastTask = serverState.task + let renderRequest = undefined + do { + res = await fetch('/render', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(reqBody) + }) + renderRequest = await res.json() + // status_code 503, already a task running. + } while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000)) + if (typeof renderRequest?.stream !== 'string') { + console.log('Endpoint response: ', renderRequest) + throw new Error('Endpoint response does not contains a response stream url.') + } + task['taskStatusLabel'].innerText = "Busy/Waiting" + do { // Wait for server status to update. + await asyncDelay(250) + if (!isServerAvailable()) { + throw new Error('Connexion with server lost.') + } + } while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task) + if (serverState.session !== 'pending' && serverState.session !== 'running') { + throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined') + } + do { // Wait for task to start on server. + await asyncDelay(1500) + } while (serverState?.session === 'pending') + + // Task started! + res = await fetch(renderRequest.stream, { headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(reqBody) }) + task['taskStatusLabel'].innerText = "Processing" + task['taskStatusLabel'].classList.add('activeTaskLabel') + task['taskStatusLabel'].classList.remove('waitingTaskLabel') + + let stepUpdate = undefined let reader = res.body.getReader() let textDecoder = new TextDecoder() let finalJSON = '' let prevTime = -1 let readComplete = false - while (true) { - let t = new Date().getTime() - + while (!readComplete || finalJSON.length > 0) { + let t = Date.now() let jsonStr = '' if (!readComplete) { const {value, done} = await reader.read() if (done) { readComplete = true } - if (done && finalJSON.length <= 0 && !value) { - break - } if (value) { jsonStr = textDecoder.decode(value) } } + stepUpdate = undefined try { // hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot. // this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste... @@ -537,9 +608,6 @@ async function doMakeImage(task) { throw e } } - if (readComplete && finalJSON.length <= 0) { - break - } if (typeof stepUpdate === 'object' && 'step' in stepUpdate) { let batchSize = stepUpdate.total_steps let overallStepCount = stepUpdate.step + task.batchesDone * batchSize @@ -564,6 +632,23 @@ async function doMakeImage(task) { showImages(reqBody, stepUpdate, outputContainer, true) } } + if (stepUpdate?.status) { + break + } + if (readComplete && finalJSON.length <= 0) { + if (res.status === 200) { + await asyncDelay(5000) + res = await fetch(renderRequest.stream, { + headers: { + 'Content-Type': 'application/json' + }, + }) + reader = res.body.getReader() + readComplete = false + } else { + console.log('Stream stopped: ', res) + } + } prevTime = t } @@ -580,27 +665,28 @@ async function doMakeImage(task) { 3. Try generating a smaller image.
` } } else { - msg = `Unexpected Read Error:
StepUpdate:${JSON.stringify(stepUpdate, undefined, 4)}
` + msg = `Unexpected Read Error:
StepUpdate: ${JSON.stringify(stepUpdate, undefined, 4)}
` } logError(msg, res, outputMsg) return false } if (typeof stepUpdate !== 'object' || !res || res.status != 200) { - if (serverStatus !== 'online') { + if (!isServerAvailable()) { logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg) } else if (typeof res === 'object') { let msg = 'Stable Diffusion had an error reading the response: ' try { // 'Response': body stream already read msg += 'Read: ' + await res.text() } catch(e) { - msg += 'No error response. ' + msg += 'Unexpected end of stream. ' } if (finalJSON) { msg += 'Buffered data: ' + finalJSON } logError(msg, res, outputMsg) } else { - msg = `Unexpected Read Error:
Response:${res}
StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}
` + let msg = `Unexpected Read Error:
Response: ${res}
StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}
` + logError(msg, res, outputMsg) } progressBar.style.display = 'none' return false @@ -654,8 +740,8 @@ async function checkTasks() { task.isProcessing = true task['stopTask'].innerHTML = ' Stop' - task['taskStatusLabel'].innerText = "Processing" - task['taskStatusLabel'].className += " activeTaskLabel" + task['taskStatusLabel'].innerText = "Starting" + task['taskStatusLabel'].classList.add('waitingTaskLabel') const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1)) const startSeed = task.reqBody.seed || task.seed @@ -780,8 +866,8 @@ function getCurrentUserRequest() { } function makeImage() { - if (serverStatus !== 'online') { - alert('The server is still starting up..') + if (!isServerAvailable()) { + alert('The server is not available.') return } const taskTemplate = getCurrentUserRequest() @@ -834,7 +920,7 @@ function createTask(task) { if (task['isProcessing']) { task.isProcessing = false try { - let res = await fetch('/image/stop') + let res = await fetch('/image/stop?session_id=' + sessionId) } catch (e) { console.log(e) } @@ -1094,9 +1180,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider) updatePromptStrength() useBetaChannelField.addEventListener('click', async function(e) { - if (serverStatus !== 'online') { + if (!isServerAvailable()) { // logError('The server is still starting up..') - alert('The server is still starting up..') + alert('The server is not available.') e.preventDefault() return false } diff --git a/ui/server.py b/ui/server.py index 41bcaee9..3414cf66 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,28 +14,55 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +TASK_TTL = 15 * 60 * 1000 # Discard last session's task timeout from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles -from starlette.responses import FileResponse, StreamingResponse +from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel import logging +import queue, threading, time +from typing import Any, Generator, Hashable, Optional, Union from sd_internal import Request, Response app = FastAPI() -model_loaded = False -model_is_loading = False - modifiers_cache = None outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) # don't show access log entries for URLs that start with the given prefix -ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails'] +ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] +NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") +class SymbolClass(type): # Print nicely formatted Symbol names. + def __repr__(self): return self.__qualname__ + def __str__(self): return self.__name__ +class Symbol(metaclass=SymbolClass): pass + +class ServerStates: + class Init(Symbol): pass + class LoadingModel(Symbol): pass + class Online(Symbol): pass + class Rendering(Symbol): pass + class Unavailable(Symbol): pass + +class RenderTask(): # Task with output queue and completion lock. + def __init__(self, req: Request): + self.request: Request = req # Initial Request + self.response: Any = None # Copy of the last reponse + self.temp_images:[] = [None] * req.num_outputs + self.error: Exception = None + self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + +current_state = ServerStates.Init +current_state_error:Exception = None +current_model_path = None +tasks_queue = queue.Queue() + # defaults from https://huggingface.co/blog/stable_diffusion class ImageRequest(BaseModel): session_id: str = "session" @@ -65,38 +92,152 @@ class ImageRequest(BaseModel): stream_progress_updates: bool = False stream_image_progress: bool = False +# Temporary cache to allow to query tasks results for a short time after they are completed. +class TaskCache(): + def __init__(self): + self._base = dict() + def _get_ttl_time(self, ttl: int) -> int: + return int(time.time()) + ttl + def _is_expired(self, timestamp: int) -> bool: + return int(time.time()) >= timestamp + def clean(self) -> None: + for key in self._base: + ttl, _ = self._base[key] + if self._is_expired(ttl): + del self._base[key] + def clear(self) -> None: + self._base.clear() + def delete(self, key: Hashable) -> bool: + if key not in self._base: + return False + del self._base[key] + return True + def keep(self, key: Hashable, ttl: int) -> bool: + if key in self._base: + _, value = self._base.get(key) + self._base[key] = (self._get_ttl_time(ttl), value) + return True + return False + def put(self, key: Hashable, value: Any, ttl: int) -> bool: + try: + self._base[key] = ( + self._get_ttl_time(ttl), value + ) + except Exception: + return False + return True + def tryGet(self, key: Hashable) -> Any: + ttl, value = self._base.get(key, (None, None)) + if ttl is not None and self._is_expired(ttl): + self.delete(key) + return None + return value + +task_cache = TaskCache() + class SetAppConfigRequest(BaseModel): update_branch: str = "main" @app.get('/') def read_root(): - headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers) - -@app.get('/ping') -async def ping(): - global model_loaded, model_is_loading + return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) +def preload_model(file_path=None): + global current_state, current_state_error, current_model_path + if file_path == None: + file_path = get_initial_model_to_load() + if file_path == current_model_path: + return + current_state = ServerStates.LoadingModel try: - if model_loaded: - return {'OK'} - - if model_is_loading: - return {'ERROR'} - - model_is_loading = True - from sd_internal import runtime - - runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load()) - - model_loaded = True - model_is_loading = False - - return {'OK'} + runtime.load_model_ckpt(ckpt_to_use=file_path) + current_model_path = file_path + current_state_error = None + current_state = ServerStates.Online except Exception as e: + current_model_path = None + current_state_error = e + current_state = ServerStates.Unavailable print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) + +def thread_render(): + global current_state, current_state_error + from sd_internal import runtime + current_state = ServerStates.Online + preload_model() + while True: + task_cache.clean() + task = None + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + try: + task = tasks_queue.get(timeout=1) + except queue.Empty as e: + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + else: continue + #if current_model_path != task.request.use_stable_diffusion_model: + # preload_model(task.request.use_stable_diffusion_model) + if current_state_error: + task.error = current_state_error + continue + print(f'Session {task.request.session_id} starting task {id(task)}') + current_state = ServerStates.Rendering + try: + task.lock.acquire(blocking=False) + res = runtime.mk_img(task.request) + except Exception as e: + task.error = e + task.lock.release() + tasks_queue.task_done() + print(traceback.format_exc()) + continue + dataQueue = None + if task.request.stream_progress_updates: + dataQueue = task.buffer_queue + for result in res: + if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): + runtime.stop_processing = True + if isinstance(current_state_error, StopAsyncIteration): + task.error = current_state_error + current_state_error = None + print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + if dataQueue: + dataQueue.put(result) + if isinstance(result, str): + result = json.loads(result) + task.response = result + if 'output' in result: + for out_obj in result['output']: + if 'path' in out_obj: + img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] + task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] + elif 'data' in out_obj: + task.temp_images[result['output'].index(out_obj)] = out_obj['data'] + task_cache.keep(task.request.session_id, TASK_TTL) + # Task completed + task.lock.release() + tasks_queue.task_done() + task_cache.keep(task.request.session_id, TASK_TTL) + if isinstance(task.error, StopAsyncIteration): + print(f'Session {task.request.session_id} task {id(task)} cancelled!') + elif task.error is not None: + print(f'Session {task.request.session_id} task {id(task)} failed!') + else: + print(f'Session {task.request.session_id} task {id(task)} completed.') + current_state = ServerStates.Online +# Start Rendering Thread +render_thread = threading.Thread(target=thread_render) +render_thread.daemon = True +render_thread.start() + +@app.on_event("shutdown") +def shutdown_event(): # Signal render thread to close on shutdown + global current_state_error + current_state_error = SystemExit('Application shutting down.') # needs to support the legacy installations def get_initial_model_to_load(): @@ -126,7 +267,6 @@ def resolve_model_to_use(model_name): model_path = legacy_model_path else: model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) - return model_path def save_model_to_config(model_name): @@ -135,13 +275,42 @@ def save_model_to_config(model_name): config['model'] = {} config['model']['stable-diffusion'] = model_name - setConfig(config) -@app.post('/image') -def image(req : ImageRequest): - from sd_internal import runtime +@app.get('/ping') # Get server and optionally session status. +def ping(session_id:str=None): + if current_state_error or not render_thread.is_alive(): # Render thread is dead. + return HTTPException(status_code=500, detail=str(current_state_error)) + # Alive + response = {'status': str(current_state)} + if session_id: + task = task_cache.tryGet(session_id) + if task: + response['task'] = id(task) + if task.lock.locked(): + response['session'] = 'running' + elif isinstance(task.error, StopAsyncIteration): + response['session'] = 'stopped' + elif task.error: + response['session'] = 'error' + elif not task.buffer_queue.empty(): + response['session'] = 'buffer' + elif task.response: + response['session'] = 'completed' + else: + response['session'] = 'pending' + return JSONResponse(response, headers=NOCACHE_HEADERS) +@app.post('/render') +def render(req : ImageRequest): + if not render_thread.is_alive(): # Render thread is dead + return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error + # Alive, check if task in cache + task = task_cache.tryGet(req.session_id) + if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one. + return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable + # + from sd_internal import runtime r = Request() r.session_id = req.session_id r.prompt = req.prompt @@ -173,45 +342,70 @@ def image(req : ImageRequest): save_model_to_config(req.use_stable_diffusion_model) + if not req.stream_progress_updates: + r.stream_image_progress = False + + new_task = RenderTask(r) + task_cache.put(r.session_id, new_task, TASK_TTL) + tasks_queue.put(new_task) + + response = { + 'status': str(current_state), + 'queue': tasks_queue.qsize(), + 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', + 'task': id(new_task) + } + return JSONResponse(response, headers=NOCACHE_HEADERS) + +async def read_data_generator(data:queue.Queue, lock:threading.Lock): try: - if not req.stream_progress_updates: - r.stream_image_progress = False + while not data.empty(): + res = data.get(block=False) + data.task_done() + yield res + except queue.Empty as e: yield - res = runtime.mk_img(r) - - if req.stream_progress_updates: - return StreamingResponse(res, media_type='application/json') - else: # compatibility mode: buffer the streaming responses, and return the last one - last_result = None - - for result in res: - last_result = result - - return json.loads(last_result) - except Exception as e: - print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) +@app.get('/image/stream/{session_id:str}/{task_id:int}') +def stream(session_id:str, task_id:int): + #TODO Move to WebSockets ?? + task = task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone + if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict + if task.buffer_queue.empty() and not task.lock.locked(): + if task.response: + #print(f'Session {session_id} sending cached response') + return JSONResponse(task.response, headers=NOCACHE_HEADERS) + return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early + #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json') @app.get('/image/stop') -def stop(): - try: - if model_is_loading: - return {'ERROR'} - - from sd_internal import runtime - runtime.stop_processing = True - +def stop(session_id:str=None): + if not session_id: + if current_state == ServerStates.Online or current_state == ServerStates.Unavailable: + return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict + global current_state_error + current_state_error = StopAsyncIteration() return {'OK'} - except Exception as e: - print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) + task = task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found + if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict + task.error = StopAsyncIteration('') + return {'OK'} -@app.get('/image/tmp/{session_id}/{img_id}') +@app.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): - from sd_internal import runtime - buf = runtime.temp_images[session_id + '/' + img_id] - buf.seek(0) - return StreamingResponse(buf, media_type='image/jpeg') + task = task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone + if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early + try: + img_data = task.temp_images[img_id] + if isinstance(img_data, str): + return img_data + img_data.seek(0) + return StreamingResponse(img_data, media_type='image/jpeg') + except KeyError as e: + return HTTPException(status_code=500, detail=str(e)) @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): @@ -257,7 +451,6 @@ def getConfig(default_val={}): def setConfig(config): try: config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w') as f: return json.dump(config, f) except: @@ -316,9 +509,7 @@ class LogSuppressFilter(logging.Filter): for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: if path.find(prefix) != -1: return False - return True - logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) # start the browser ui