From a6e5474fdbf3c097bff6882dc3ad405eb35259ff Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 00:56:04 -0400 Subject: [PATCH 01/20] CSS waitingTaskLabel for task waiting to start --- ui/media/main.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ui/media/main.css b/ui/media/main.css index 1fc21e68..e8d6f342 100644 --- a/ui/media/main.css +++ b/ui/media/main.css @@ -389,6 +389,11 @@ img { border: 1px solid rgb(0, 75, 19); color:rgb(204, 255, 217) } +.waitingTaskLabel { + background:rgb(90, 90, 0); + border: 1px solid rgb(0, 75, 19); + color:rgb(255, 255, 204) +} .secondaryButton { background: rgb(132, 8, 0); border: 1px solid rgb(122, 29, 0); From bc56226a282ac8f7f1e9b3c543bec981947533ce Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 03:42:43 -0400 Subject: [PATCH 02/20] Grouped many endpoints into one --- ui/media/main.js | 11 ++++++----- ui/server.py | 47 ++++++++++++++++++++--------------------------- 2 files changed, 26 insertions(+), 32 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index 0bcb2fc3..e2513703 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -1123,7 +1123,7 @@ useBetaChannelField.addEventListener('click', async function(e) { async function getAppConfig() { try { - let res = await fetch('/app_config') + let res = await fetch('/get?key=app_config') const config = await res.json() if (config.update_branch === 'beta') { @@ -1139,7 +1139,7 @@ async function getAppConfig() { async function getModels() { try { - let res = await fetch('/models') + let res = await fetch('/get?key=models') const models = await res.json() let activeModel = models['active'] @@ -1404,7 +1404,7 @@ async function getDiskPath() { return } - let res = await fetch('/output_dir') + let res = await fetch('/get?key=output_dir') if (res.status === 200) { res = await res.json() res = res[0] @@ -1515,14 +1515,15 @@ function resizeModifierCards(val) { const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix)) card.className = classes.join(' ').trim() - if(val != 0) + if(val != 0) { card.classList.add(cardSize(val)) + } }) } async function loadModifiers() { try { - let res = await fetch('/modifiers.json?v=2') + let res = await fetch('/get?key=modifiers') if (res.status === 200) { res = await res.json() diff --git a/ui/server.py b/ui/server.py index 5d7f1bfd..41bcaee9 100644 --- a/ui/server.py +++ b/ui/server.py @@ -242,31 +242,17 @@ async def setAppConfig(req : SetAppConfigRequest): print(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e)) -@app.get('/app_config') -def getAppConfig(): +def getConfig(default_val={}): try: config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return HTTPException(status_code=500, detail="No config file") - + return default_val with open(config_json_path, 'r') as f: return json.load(f) except Exception as e: + print(str(e)) print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) - -def getConfig(): - try: - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - - if not os.path.exists(config_json_path): - return {} - - with open(config_json_path, 'r') as f: - return json.load(f) - except Exception as e: - return {} + return default_val def setConfig(config): try: @@ -275,9 +261,9 @@ def setConfig(config): with open(config_json_path, 'w') as f: return json.dump(config, f) except: + print(str(e)) print(traceback.format_exc()) -@app.get('/models') def getModels(): models = { 'active': { @@ -307,14 +293,21 @@ def getModels(): return models -@app.get('/modifiers.json') -def read_modifiers(): - headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=headers) - -@app.get('/output_dir') -def read_home_dir(): - return {outpath} +@app.get('/get') +def read_web_data(key:str=None): + if key is None: # /get without parameters, stable-diffusion easter egg. + return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot + elif key == 'app_config': + config = getConfig(default_val=None) + if config is None: + return HTTPException(status_code=500, detail="Config file is missing or unreadable") + return config + elif key == 'models': + return getModels() + elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) + elif key == 'output_dir': return {outpath} + else: + return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found # don't log certain requests class LogSuppressFilter(logging.Filter): From 4b88cfa51ac0cbf302dfa93d6a93e91574558d76 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 03:43:33 -0400 Subject: [PATCH 03/20] More simple time check --- ui/media/main.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index e2513703..0d7757d7 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -18,7 +18,7 @@ const INPAINTING_EDITOR_SIZE = 450 const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64') -let sessionId = new Date().getTime() +let sessionId = Date.now() let promptField = document.querySelector('#prompt') let promptsFromFileSelector = document.querySelector('#prompt_from_file') @@ -648,7 +648,7 @@ async function checkTasks() { let task = taskQueue.pop() currentTask = task - let time = new Date().getTime() + let time = Date.now() let successCount = 0 @@ -690,7 +690,7 @@ async function checkTasks() { task['stopTask'].innerHTML = ' Remove' task['taskStatusLabel'].style.display = 'none' - time = new Date().getTime() - time + time = Date.now() - time time /= 1000 if (successCount === task.batchCount) { From 1ec9d986bb866d2eee92bb20c2a707091d74197e Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 03:47:25 -0400 Subject: [PATCH 04/20] Render queue first draft --- ui/media/main.js | 184 ++++++++++++++++++++------- ui/server.py | 325 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 393 insertions(+), 116 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index 0d7757d7..2d8fd6cd 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -117,7 +117,7 @@ maskResetButton.innerHTML = 'Clear' maskResetButton.style.fontWeight = 'normal' maskResetButton.style.fontSize = '10pt' -let serverStatus = 'offline' +let serverState = {'status': 'Offline'} let activeTags = [] let modifiers = [] let lastPromptUsed = '' @@ -212,21 +212,38 @@ function getOutputFormat() { } function setStatus(statusType, msg, msgType) { - if (statusType !== 'server') { - return - } +} - if (msgType == 'error') { - // msg = '' + msg + '' - serverStatusColor.style.color = 'red' - serverStatusMsg.style.color = 'red' - serverStatusMsg.innerText = 'Stable Diffusion has stopped' - } else if (msgType == 'success') { - // msg = '' + msg + '' - serverStatusColor.style.color = 'green' - serverStatusMsg.style.color = 'green' - serverStatusMsg.innerText = 'Stable Diffusion is ready' - serverStatus = 'online' +function setServerStatus(msgType, msg) { + switch(msgType) { + case 'online': + serverStatusColor.style.color = 'green' + serverStatusMsg.style.color = 'green' + serverStatusMsg.innerText = 'Stable Diffusion is ' + msg + break + case 'busy': + serverStatusColor.style.color = 'yellow' + serverStatusMsg.style.color = 'yellow' + serverStatusMsg.innerText = 'Stable Diffusion is ' + msg + break + case 'error': + serverStatusColor.style.color = 'red' + serverStatusMsg.style.color = 'red' + serverStatusMsg.innerText = 'Stable Diffusion has stopped' + break + } +} +function isServerAvailable() { + if (typeof serverState !== 'object') { + return false + } + switch (serverState.status) { + case 'LoadingModel': + case 'Rendering': + case 'Online': + return true + default: + return false } } @@ -250,6 +267,11 @@ function logError(msg, res, outputMsg) { console.log('request error', res) setStatus('request', 'error', 'error') } +function asyncDelay(timeout) { + return new Promise(function(resolve, reject) { + setTimeout(resolve, timeout, true) + }) +} function playSound() { const audio = new Audio('/media/ding.mp3') @@ -259,16 +281,34 @@ function playSound() { async function healthCheck() { try { - let res = await fetch('/ping') - res = await res.json() - - if (res[0] == 'OK') { - setStatus('server', 'online', 'success') + let res = undefined + if (sessionId) { + res = await fetch('/ping?session_id=' + sessionId) } else { - setStatus('server', 'offline', 'error') + res = await fetch('/ping') } + serverState = await res.json() + // Set status + switch(serverState.status) { + case 'Init': + // Wait for init to complete before updating status. + break + case 'Online': + setServerStatus('online', 'ready') + break + case 'LoadingModel': + setServerStatus('busy', 'loading model') + break + case 'Rendering': + setServerStatus('busy', 'rendering') + break + default: // Unavailable + setServerStatus('error', serverState.status.toLowerCase()) + break + } + serverState.time = Date.now() } catch (e) { - setStatus('server', 'offline', 'error') + setServerStatus('error', 'offline') } } function resizeInpaintingEditor() { @@ -311,7 +351,7 @@ function showImages(reqBody, res, outputContainer, livePreview) { if(typeof res != 'object') return res.output.reverse() res.output.forEach((result, index) => { - const imageData = result?.data || result?.path + '?t=' + new Date().getTime() + const imageData = result?.data || result?.path + '?t=' + Date.now() const imageWidth = reqBody.width const imageHeight = reqBody.height if (!imageData.includes('/')) { @@ -409,8 +449,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) { } function getStartNewTaskHandler(reqBody, imageItemElem, mode) { return function() { - if (serverStatus !== 'online') { - alert('The server is still starting up..') + if (!isServerAvailable()) { + alert('The server is not available.') return } const imageElem = imageItemElem.querySelector('img') @@ -473,37 +513,68 @@ async function doMakeImage(task) { const progressBar = task['progressBar'] let res = undefined - let stepUpdate = undefined try { - res = await fetch('/image', { - method: 'POST', + const lastTask = serverState.task + let renderRequest = undefined + do { + res = await fetch('/render', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(reqBody) + }) + renderRequest = await res.json() + // status_code 503, already a task running. + } while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000)) + if (typeof renderRequest?.stream !== 'string') { + console.log('Endpoint response: ', renderRequest) + throw new Error('Endpoint response does not contains a response stream url.') + } + task['taskStatusLabel'].innerText = "Busy/Waiting" + do { // Wait for server status to update. + await asyncDelay(250) + if (!isServerAvailable()) { + throw new Error('Connexion with server lost.') + } + } while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task) + if (serverState.session !== 'pending' && serverState.session !== 'running') { + throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined') + } + do { // Wait for task to start on server. + await asyncDelay(1500) + } while (serverState?.session === 'pending') + + // Task started! + res = await fetch(renderRequest.stream, { headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(reqBody) }) + task['taskStatusLabel'].innerText = "Processing" + task['taskStatusLabel'].classList.add('activeTaskLabel') + task['taskStatusLabel'].classList.remove('waitingTaskLabel') + + let stepUpdate = undefined let reader = res.body.getReader() let textDecoder = new TextDecoder() let finalJSON = '' let prevTime = -1 let readComplete = false - while (true) { - let t = new Date().getTime() - + while (!readComplete || finalJSON.length > 0) { + let t = Date.now() let jsonStr = '' if (!readComplete) { const {value, done} = await reader.read() if (done) { readComplete = true } - if (done && finalJSON.length <= 0 && !value) { - break - } if (value) { jsonStr = textDecoder.decode(value) } } + stepUpdate = undefined try { // hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot. // this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste... @@ -537,9 +608,6 @@ async function doMakeImage(task) { throw e } } - if (readComplete && finalJSON.length <= 0) { - break - } if (typeof stepUpdate === 'object' && 'step' in stepUpdate) { let batchSize = stepUpdate.total_steps let overallStepCount = stepUpdate.step + task.batchesDone * batchSize @@ -564,6 +632,23 @@ async function doMakeImage(task) { showImages(reqBody, stepUpdate, outputContainer, true) } } + if (stepUpdate?.status) { + break + } + if (readComplete && finalJSON.length <= 0) { + if (res.status === 200) { + await asyncDelay(5000) + res = await fetch(renderRequest.stream, { + headers: { + 'Content-Type': 'application/json' + }, + }) + reader = res.body.getReader() + readComplete = false + } else { + console.log('Stream stopped: ', res) + } + } prevTime = t } @@ -580,27 +665,28 @@ async function doMakeImage(task) { 3. Try generating a smaller image.
` } } else { - msg = `Unexpected Read Error:
StepUpdate:${JSON.stringify(stepUpdate, undefined, 4)}
` + msg = `Unexpected Read Error:
StepUpdate: ${JSON.stringify(stepUpdate, undefined, 4)}
` } logError(msg, res, outputMsg) return false } if (typeof stepUpdate !== 'object' || !res || res.status != 200) { - if (serverStatus !== 'online') { + if (!isServerAvailable()) { logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg) } else if (typeof res === 'object') { let msg = 'Stable Diffusion had an error reading the response: ' try { // 'Response': body stream already read msg += 'Read: ' + await res.text() } catch(e) { - msg += 'No error response. ' + msg += 'Unexpected end of stream. ' } if (finalJSON) { msg += 'Buffered data: ' + finalJSON } logError(msg, res, outputMsg) } else { - msg = `Unexpected Read Error:
Response:${res}
StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}
` + let msg = `Unexpected Read Error:
Response: ${res}
StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}
` + logError(msg, res, outputMsg) } progressBar.style.display = 'none' return false @@ -654,8 +740,8 @@ async function checkTasks() { task.isProcessing = true task['stopTask'].innerHTML = ' Stop' - task['taskStatusLabel'].innerText = "Processing" - task['taskStatusLabel'].className += " activeTaskLabel" + task['taskStatusLabel'].innerText = "Starting" + task['taskStatusLabel'].classList.add('waitingTaskLabel') const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1)) const startSeed = task.reqBody.seed || task.seed @@ -780,8 +866,8 @@ function getCurrentUserRequest() { } function makeImage() { - if (serverStatus !== 'online') { - alert('The server is still starting up..') + if (!isServerAvailable()) { + alert('The server is not available.') return } const taskTemplate = getCurrentUserRequest() @@ -834,7 +920,7 @@ function createTask(task) { if (task['isProcessing']) { task.isProcessing = false try { - let res = await fetch('/image/stop') + let res = await fetch('/image/stop?session_id=' + sessionId) } catch (e) { console.log(e) } @@ -1094,9 +1180,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider) updatePromptStrength() useBetaChannelField.addEventListener('click', async function(e) { - if (serverStatus !== 'online') { + if (!isServerAvailable()) { // logError('The server is still starting up..') - alert('The server is still starting up..') + alert('The server is not available.') e.preventDefault() return false } diff --git a/ui/server.py b/ui/server.py index 41bcaee9..3414cf66 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,28 +14,55 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +TASK_TTL = 15 * 60 * 1000 # Discard last session's task timeout from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles -from starlette.responses import FileResponse, StreamingResponse +from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel import logging +import queue, threading, time +from typing import Any, Generator, Hashable, Optional, Union from sd_internal import Request, Response app = FastAPI() -model_loaded = False -model_is_loading = False - modifiers_cache = None outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) # don't show access log entries for URLs that start with the given prefix -ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails'] +ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] +NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") +class SymbolClass(type): # Print nicely formatted Symbol names. + def __repr__(self): return self.__qualname__ + def __str__(self): return self.__name__ +class Symbol(metaclass=SymbolClass): pass + +class ServerStates: + class Init(Symbol): pass + class LoadingModel(Symbol): pass + class Online(Symbol): pass + class Rendering(Symbol): pass + class Unavailable(Symbol): pass + +class RenderTask(): # Task with output queue and completion lock. + def __init__(self, req: Request): + self.request: Request = req # Initial Request + self.response: Any = None # Copy of the last reponse + self.temp_images:[] = [None] * req.num_outputs + self.error: Exception = None + self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + +current_state = ServerStates.Init +current_state_error:Exception = None +current_model_path = None +tasks_queue = queue.Queue() + # defaults from https://huggingface.co/blog/stable_diffusion class ImageRequest(BaseModel): session_id: str = "session" @@ -65,38 +92,152 @@ class ImageRequest(BaseModel): stream_progress_updates: bool = False stream_image_progress: bool = False +# Temporary cache to allow to query tasks results for a short time after they are completed. +class TaskCache(): + def __init__(self): + self._base = dict() + def _get_ttl_time(self, ttl: int) -> int: + return int(time.time()) + ttl + def _is_expired(self, timestamp: int) -> bool: + return int(time.time()) >= timestamp + def clean(self) -> None: + for key in self._base: + ttl, _ = self._base[key] + if self._is_expired(ttl): + del self._base[key] + def clear(self) -> None: + self._base.clear() + def delete(self, key: Hashable) -> bool: + if key not in self._base: + return False + del self._base[key] + return True + def keep(self, key: Hashable, ttl: int) -> bool: + if key in self._base: + _, value = self._base.get(key) + self._base[key] = (self._get_ttl_time(ttl), value) + return True + return False + def put(self, key: Hashable, value: Any, ttl: int) -> bool: + try: + self._base[key] = ( + self._get_ttl_time(ttl), value + ) + except Exception: + return False + return True + def tryGet(self, key: Hashable) -> Any: + ttl, value = self._base.get(key, (None, None)) + if ttl is not None and self._is_expired(ttl): + self.delete(key) + return None + return value + +task_cache = TaskCache() + class SetAppConfigRequest(BaseModel): update_branch: str = "main" @app.get('/') def read_root(): - headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers) - -@app.get('/ping') -async def ping(): - global model_loaded, model_is_loading + return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) +def preload_model(file_path=None): + global current_state, current_state_error, current_model_path + if file_path == None: + file_path = get_initial_model_to_load() + if file_path == current_model_path: + return + current_state = ServerStates.LoadingModel try: - if model_loaded: - return {'OK'} - - if model_is_loading: - return {'ERROR'} - - model_is_loading = True - from sd_internal import runtime - - runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load()) - - model_loaded = True - model_is_loading = False - - return {'OK'} + runtime.load_model_ckpt(ckpt_to_use=file_path) + current_model_path = file_path + current_state_error = None + current_state = ServerStates.Online except Exception as e: + current_model_path = None + current_state_error = e + current_state = ServerStates.Unavailable print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) + +def thread_render(): + global current_state, current_state_error + from sd_internal import runtime + current_state = ServerStates.Online + preload_model() + while True: + task_cache.clean() + task = None + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + try: + task = tasks_queue.get(timeout=1) + except queue.Empty as e: + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + else: continue + #if current_model_path != task.request.use_stable_diffusion_model: + # preload_model(task.request.use_stable_diffusion_model) + if current_state_error: + task.error = current_state_error + continue + print(f'Session {task.request.session_id} starting task {id(task)}') + current_state = ServerStates.Rendering + try: + task.lock.acquire(blocking=False) + res = runtime.mk_img(task.request) + except Exception as e: + task.error = e + task.lock.release() + tasks_queue.task_done() + print(traceback.format_exc()) + continue + dataQueue = None + if task.request.stream_progress_updates: + dataQueue = task.buffer_queue + for result in res: + if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): + runtime.stop_processing = True + if isinstance(current_state_error, StopAsyncIteration): + task.error = current_state_error + current_state_error = None + print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + if dataQueue: + dataQueue.put(result) + if isinstance(result, str): + result = json.loads(result) + task.response = result + if 'output' in result: + for out_obj in result['output']: + if 'path' in out_obj: + img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] + task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] + elif 'data' in out_obj: + task.temp_images[result['output'].index(out_obj)] = out_obj['data'] + task_cache.keep(task.request.session_id, TASK_TTL) + # Task completed + task.lock.release() + tasks_queue.task_done() + task_cache.keep(task.request.session_id, TASK_TTL) + if isinstance(task.error, StopAsyncIteration): + print(f'Session {task.request.session_id} task {id(task)} cancelled!') + elif task.error is not None: + print(f'Session {task.request.session_id} task {id(task)} failed!') + else: + print(f'Session {task.request.session_id} task {id(task)} completed.') + current_state = ServerStates.Online +# Start Rendering Thread +render_thread = threading.Thread(target=thread_render) +render_thread.daemon = True +render_thread.start() + +@app.on_event("shutdown") +def shutdown_event(): # Signal render thread to close on shutdown + global current_state_error + current_state_error = SystemExit('Application shutting down.') # needs to support the legacy installations def get_initial_model_to_load(): @@ -126,7 +267,6 @@ def resolve_model_to_use(model_name): model_path = legacy_model_path else: model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) - return model_path def save_model_to_config(model_name): @@ -135,13 +275,42 @@ def save_model_to_config(model_name): config['model'] = {} config['model']['stable-diffusion'] = model_name - setConfig(config) -@app.post('/image') -def image(req : ImageRequest): - from sd_internal import runtime +@app.get('/ping') # Get server and optionally session status. +def ping(session_id:str=None): + if current_state_error or not render_thread.is_alive(): # Render thread is dead. + return HTTPException(status_code=500, detail=str(current_state_error)) + # Alive + response = {'status': str(current_state)} + if session_id: + task = task_cache.tryGet(session_id) + if task: + response['task'] = id(task) + if task.lock.locked(): + response['session'] = 'running' + elif isinstance(task.error, StopAsyncIteration): + response['session'] = 'stopped' + elif task.error: + response['session'] = 'error' + elif not task.buffer_queue.empty(): + response['session'] = 'buffer' + elif task.response: + response['session'] = 'completed' + else: + response['session'] = 'pending' + return JSONResponse(response, headers=NOCACHE_HEADERS) +@app.post('/render') +def render(req : ImageRequest): + if not render_thread.is_alive(): # Render thread is dead + return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error + # Alive, check if task in cache + task = task_cache.tryGet(req.session_id) + if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one. + return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable + # + from sd_internal import runtime r = Request() r.session_id = req.session_id r.prompt = req.prompt @@ -173,45 +342,70 @@ def image(req : ImageRequest): save_model_to_config(req.use_stable_diffusion_model) + if not req.stream_progress_updates: + r.stream_image_progress = False + + new_task = RenderTask(r) + task_cache.put(r.session_id, new_task, TASK_TTL) + tasks_queue.put(new_task) + + response = { + 'status': str(current_state), + 'queue': tasks_queue.qsize(), + 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', + 'task': id(new_task) + } + return JSONResponse(response, headers=NOCACHE_HEADERS) + +async def read_data_generator(data:queue.Queue, lock:threading.Lock): try: - if not req.stream_progress_updates: - r.stream_image_progress = False + while not data.empty(): + res = data.get(block=False) + data.task_done() + yield res + except queue.Empty as e: yield - res = runtime.mk_img(r) - - if req.stream_progress_updates: - return StreamingResponse(res, media_type='application/json') - else: # compatibility mode: buffer the streaming responses, and return the last one - last_result = None - - for result in res: - last_result = result - - return json.loads(last_result) - except Exception as e: - print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) +@app.get('/image/stream/{session_id:str}/{task_id:int}') +def stream(session_id:str, task_id:int): + #TODO Move to WebSockets ?? + task = task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone + if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict + if task.buffer_queue.empty() and not task.lock.locked(): + if task.response: + #print(f'Session {session_id} sending cached response') + return JSONResponse(task.response, headers=NOCACHE_HEADERS) + return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early + #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json') @app.get('/image/stop') -def stop(): - try: - if model_is_loading: - return {'ERROR'} - - from sd_internal import runtime - runtime.stop_processing = True - +def stop(session_id:str=None): + if not session_id: + if current_state == ServerStates.Online or current_state == ServerStates.Unavailable: + return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict + global current_state_error + current_state_error = StopAsyncIteration() return {'OK'} - except Exception as e: - print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) + task = task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found + if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict + task.error = StopAsyncIteration('') + return {'OK'} -@app.get('/image/tmp/{session_id}/{img_id}') +@app.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): - from sd_internal import runtime - buf = runtime.temp_images[session_id + '/' + img_id] - buf.seek(0) - return StreamingResponse(buf, media_type='image/jpeg') + task = task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone + if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early + try: + img_data = task.temp_images[img_id] + if isinstance(img_data, str): + return img_data + img_data.seek(0) + return StreamingResponse(img_data, media_type='image/jpeg') + except KeyError as e: + return HTTPException(status_code=500, detail=str(e)) @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): @@ -257,7 +451,6 @@ def getConfig(default_val={}): def setConfig(config): try: config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w') as f: return json.dump(config, f) except: @@ -316,9 +509,7 @@ class LogSuppressFilter(logging.Filter): for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: if path.find(prefix) != -1: return False - return True - logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) # start the browser ui From 476e938d233373a71fe23c2ad9e10408e26a3bc1 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 04:18:34 -0400 Subject: [PATCH 05/20] Forgot a color change for batched tasks. taskStatusLabel could have class activeTaskLabel replace by waitingTaskLabel again. --- ui/media/main.js | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ui/media/main.js b/ui/media/main.js index 2d8fd6cd..260ef0a3 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -532,6 +532,9 @@ async function doMakeImage(task) { throw new Error('Endpoint response does not contains a response stream url.') } task['taskStatusLabel'].innerText = "Busy/Waiting" + task['taskStatusLabel'].classList.add('waitingTaskLabel') + task['taskStatusLabel'].classList.remove('activeTaskLabel') + do { // Wait for server status to update. await asyncDelay(250) if (!isServerAvailable()) { From f91c77bdc660bb8a5ca13ee20574a042d566638f Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 04:47:13 -0400 Subject: [PATCH 06/20] Failed task go immediately into the buffer state with the error. --- ui/media/main.js | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index 260ef0a3..345287ef 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -541,12 +541,13 @@ async function doMakeImage(task) { throw new Error('Connexion with server lost.') } } while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task) - if (serverState.session !== 'pending' && serverState.session !== 'running') { + if (serverState.session !== 'pending' && serverState.session !== 'running' && serverState.session !== 'buffer') { throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined') } - do { // Wait for task to start on server. + while (serverState?.session === 'pending') { + // Wait for task to start on server. await asyncDelay(1500) - } while (serverState?.session === 'pending') + } // Task started! res = await fetch(renderRequest.stream, { From 4a7260b1be1c3b383f2df7914d6709d1baade799 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 05:20:44 -0400 Subject: [PATCH 07/20] StopAsyncIteration should not trigger HTTP500. Now returns faster into the ready state. --- ui/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ui/server.py b/ui/server.py index 3414cf66..18d5214e 100644 --- a/ui/server.py +++ b/ui/server.py @@ -279,8 +279,10 @@ def save_model_to_config(model_name): @app.get('/ping') # Get server and optionally session status. def ping(session_id:str=None): - if current_state_error or not render_thread.is_alive(): # Render thread is dead. - return HTTPException(status_code=500, detail=str(current_state_error)) + if not render_thread.is_alive(): # Render thread is dead. + if current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) + return HTTPException(status_code=500, detail='Render thread is dead.') + if current_state_error and not isinstance(current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) # Alive response = {'status': str(current_state)} if session_id: From 6ae3b77c2f22e9608af83f938dcae94c23ec314a Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 06:03:18 -0400 Subject: [PATCH 08/20] LoadingModel detection --- ui/server.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ui/server.py b/ui/server.py index 18d5214e..1988cab2 100644 --- a/ui/server.py +++ b/ui/server.py @@ -162,16 +162,16 @@ def preload_model(file_path=None): print(traceback.format_exc()) def thread_render(): - global current_state, current_state_error + global current_state, current_state_error, current_model_path from sd_internal import runtime current_state = ServerStates.Online preload_model() while True: task_cache.clean() - task = None if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable return + task = None try: task = tasks_queue.get(timeout=1) except queue.Empty as e: @@ -185,10 +185,13 @@ def thread_render(): task.error = current_state_error continue print(f'Session {task.request.session_id} starting task {id(task)}') - current_state = ServerStates.Rendering try: task.lock.acquire(blocking=False) res = runtime.mk_img(task.request) + if current_model_path == task.request.use_stable_diffusion_model: + current_state = ServerStates.Rendering + else: + current_state = ServerStates.LoadingModel except Exception as e: task.error = e task.lock.release() @@ -199,6 +202,9 @@ def thread_render(): if task.request.stream_progress_updates: dataQueue = task.buffer_queue for result in res: + if current_state == ServerStates.LoadingModel: + current_state = ServerStates.Rendering + current_model_path = task.request.use_stable_diffusion_model if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): runtime.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): From c7f6763c48405ed663144875af78d679ade653a4 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 23:20:57 -0400 Subject: [PATCH 09/20] Runtime cleanup and moved apply_filters to it's own function --- ui/sd_internal/runtime.py | 150 ++++++++++++++++++++------------------ 1 file changed, 79 insertions(+), 71 deletions(-) diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 63506579..0b0a3003 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -197,6 +197,35 @@ def load_model_real_esrgan(real_esrgan_to_use): print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision) +def get_base_path(disk_path, session_id, prompt, ext, suffix=None): + if disk_path is None: return None + if session_id is None: return None + if ext is None: raise Exception('Missing ext') + + session_out_path = os.path.join(disk_path, session_id) + os.makedirs(session_out_path, exist_ok=True) + + prompt_flattened = filename_regex.sub('_', prompt)[:50] + img_id = str(uuid.uuid4())[-8:] + + if suffix is not None: + return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}") + return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}") + +def apply_filters(filter_name, image_data): + print(f'Applying filter {filter_name}...') + gc() + + if filter_name == 'gfpgan': + _, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + image_data = output[:,:,::-1] + + if filter_name == 'real_esrgan': + output, _ = model_real_esrgan.enhance(image_data[:,:,::-1]) + image_data = output[:,:,::-1] + + return image_data + def mk_img(req: Request): try: yield from do_mk_img(req) @@ -283,23 +312,11 @@ def do_mk_img(req: Request): opt_prompt = req.prompt opt_seed = req.seed - opt_n_samples = req.num_outputs opt_n_iter = 1 - opt_scale = req.guidance_scale opt_C = 4 - opt_H = req.height - opt_W = req.width opt_f = 8 - opt_ddim_steps = req.num_inference_steps opt_ddim_eta = 0.0 - opt_strength = req.prompt_strength - opt_save_to_disk_path = req.save_to_disk_path opt_init_img = req.init_image - opt_use_face_correction = req.use_face_correction - opt_use_upscale = req.use_upscale - opt_show_only_filtered = req.show_only_filtered_image - opt_format = req.output_format - opt_sampler_name = req.sampler print(req.to_string(), '\n device', device) @@ -307,7 +324,7 @@ def do_mk_img(req: Request): seed_everything(opt_seed) - batch_size = opt_n_samples + batch_size = req.num_outputs prompt = opt_prompt assert prompt is not None data = [batch_size * [prompt]] @@ -327,7 +344,7 @@ def do_mk_img(req: Request): else: handler = _img2img - init_image = load_img(req.init_image, opt_W, opt_H) + init_image = load_img(req.init_image, req.width, req.height) init_image = init_image.to(device) if device != "cpu" and precision == "autocast": @@ -339,7 +356,7 @@ def do_mk_img(req: Request): init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space if req.mask is not None: - mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device) + mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device) mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) mask = repeat(mask, '1 ... -> b ...', b=batch_size) @@ -348,12 +365,12 @@ def do_mk_img(req: Request): move_fs_to_cpu() - assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(opt_strength * opt_ddim_steps) + assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(req.prompt_strength * req.num_inference_steps) print(f"target t_enc is {t_enc} steps") - if opt_save_to_disk_path is not None: - session_out_path = os.path.join(opt_save_to_disk_path, req.session_id) + if req.save_to_disk_path is not None: + session_out_path = os.path.join(req.save_to_disk_path, req.session_id) os.makedirs(session_out_path, exist_ok=True) else: session_out_path = None @@ -366,7 +383,7 @@ def do_mk_img(req: Request): with precision_scope("cuda"): modelCS.to(device) uc = None - if opt_scale != 1.0: + if req.guidance_scale != 1.0: uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -393,7 +410,7 @@ def do_mk_img(req: Request): partial_x_samples = x_samples if req.stream_progress_updates: - n_steps = opt_ddim_steps if req.init_image is None else t_enc + n_steps = req.num_inference_steps if req.init_image is None else t_enc progress = {"step": i, "total_steps": n_steps} if req.stream_image_progress and i % 5 == 0: @@ -425,9 +442,9 @@ def do_mk_img(req: Request): # run the handler try: if handler == _txt2img: - x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name) + x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) else: - x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask) + x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask) yield from x_samples @@ -447,69 +464,49 @@ def do_mk_img(req: Request): x_sample = x_sample.astype(np.uint8) img = Image.fromarray(x_sample) - has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \ - (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')) + has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \ + (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN')) - return_orig_img = not has_filters or not opt_show_only_filtered + return_orig_img = not has_filters or not req.show_only_filtered_image if stop_processing: return_orig_img = True - if opt_save_to_disk_path is not None: - prompt_flattened = filename_regex.sub('_', prompts[0]) - prompt_flattened = prompt_flattened[:50] - - img_id = str(uuid.uuid4())[-8:] - - file_path = f"{prompt_flattened}_{img_id}" - img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}") - meta_out_path = os.path.join(session_out_path, f"{file_path}.txt") - + if req.save_to_disk_path is not None: if return_orig_img: + img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format) save_image(img, img_out_path) - - save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file) + meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], 'txt') + save_metadata(meta_out_path, req, prompts[0], opt_seed) if return_orig_img: - img_data = img_to_base64_str(img, opt_format) + img_data = img_to_base64_str(img, req.output_format) res_image_orig = ResponseImage(data=img_data, seed=opt_seed) res.images.append(res_image_orig) - if opt_save_to_disk_path is not None: + if req.save_to_disk_path is not None: res_image_orig.path_abs = img_out_path del img if has_filters and not stop_processing: - print('Applying filters..') - - gc() filters_applied = [] - - if opt_use_face_correction: - _, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - x_sample = output[:,:,::-1] - filters_applied.append(opt_use_face_correction) - - if opt_use_upscale: - output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1]) - x_sample = output[:,:,::-1] - filters_applied.append(opt_use_upscale) - - filtered_image = Image.fromarray(x_sample) - - filtered_img_data = img_to_base64_str(filtered_image, opt_format) - res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed) - res.images.append(res_image_filtered) - - filters_applied = "_".join(filters_applied) - - if opt_save_to_disk_path is not None: - filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}") - save_image(filtered_image, filtered_img_out_path) - res_image_filtered.path_abs = filtered_img_out_path - - del filtered_image + if req.use_face_correction: + x_sample = apply_filters('gfpgan', x_sample) + filters_applied.append(req.use_face_correction) + if req.use_upscale: + x_sample = apply_filters('real_esrgan', x_sample) + filters_applied.append(req.use_upscale) + if (len(filters_applied) > 0): + filtered_image = Image.fromarray(x_sample) + filtered_img_data = img_to_base64_str(filtered_image, req.output_format) + response_image = ResponseImage(data=filtered_img_data, seed=req.seed) + res.images.append(response_image) + if req.save_to_disk_path is not None: + filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format, "_".join(filters_applied)) + save_image(filtered_image, filtered_img_out_path) + response_image.path_abs = filtered_img_out_path + del filtered_image seeds += str(opt_seed) + "," opt_seed += 1 @@ -529,9 +526,20 @@ def save_image(img, img_out_path): except: print('could not save the file', traceback.format_exc()) -def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file): - metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}" - +def save_metadata(meta_out_path, req, prompt, opt_seed): + metadata = f"""{prompt} +Width: {req.width} +Height: {req.height} +Seed: {opt_seed} +Steps: {req.num_inference_steps} +Guidance Scale: {req.guidance_scale} +Prompt Strength: {req.prompt_strength} +Use Face Correction: {req.use_face_correction} +Use Upscaling: {req.use_upscale} +Sampler: {req.sampler} +Negative Prompt: {req.negative_prompt} +Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'} +""" try: with open(meta_out_path, 'w') as f: f.write(metadata) From ff3db04ab758e5ec93af5927f2919abf40517d32 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Fri, 14 Oct 2022 23:21:44 -0400 Subject: [PATCH 10/20] temp_images needs twice the size if show_only_filtered_image is false --- ui/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/server.py b/ui/server.py index 1988cab2..58acda6c 100644 --- a/ui/server.py +++ b/ui/server.py @@ -53,7 +53,7 @@ class RenderTask(): # Task with output queue and completion lock. def __init__(self, req: Request): self.request: Request = req # Initial Request self.response: Any = None # Copy of the last reponse - self.temp_images:[] = [None] * req.num_outputs + self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) self.error: Exception = None self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments From 3d4e9613205368a28aa83a8175121b98689b922a Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 00:51:06 -0400 Subject: [PATCH 11/20] time.time() is in seconds not ms. --- ui/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/server.py b/ui/server.py index 58acda6c..3c9b9dd2 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,7 +14,7 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder -TASK_TTL = 15 * 60 * 1000 # Discard last session's task timeout +TASK_TTL = 15 * 60 # Discard last session's task timeout from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles From e9f9670eb5a0203d7fdcb25e74d96bbf65246c27 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 01:32:53 -0400 Subject: [PATCH 12/20] Changed '/get' from a query to a path parameter --- ui/media/main.js | 8 ++++---- ui/server.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index 345287ef..c585be5a 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -1213,7 +1213,7 @@ useBetaChannelField.addEventListener('click', async function(e) { async function getAppConfig() { try { - let res = await fetch('/get?key=app_config') + let res = await fetch('/get/app_config') const config = await res.json() if (config.update_branch === 'beta') { @@ -1229,7 +1229,7 @@ async function getAppConfig() { async function getModels() { try { - let res = await fetch('/get?key=models') + let res = await fetch('/get/models') const models = await res.json() let activeModel = models['active'] @@ -1494,7 +1494,7 @@ async function getDiskPath() { return } - let res = await fetch('/get?key=output_dir') + let res = await fetch('/get/output_dir') if (res.status === 200) { res = await res.json() res = res[0] @@ -1613,7 +1613,7 @@ function resizeModifierCards(val) { async function loadModifiers() { try { - let res = await fetch('/get?key=modifiers') + let res = await fetch('/get/modifiers') if (res.status === 200) { res = await res.json() diff --git a/ui/server.py b/ui/server.py index 3c9b9dd2..7522f234 100644 --- a/ui/server.py +++ b/ui/server.py @@ -494,19 +494,19 @@ def getModels(): return models -@app.get('/get') +@app.get('/get/{key:path}') def read_web_data(key:str=None): - if key is None: # /get without parameters, stable-diffusion easter egg. + if not key: # /get without parameters, stable-diffusion easter egg. return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot elif key == 'app_config': config = getConfig(default_val=None) if config is None: return HTTPException(status_code=500, detail="Config file is missing or unreadable") - return config + return JSONResponse(config, headers=NOCACHE_HEADERS) elif key == 'models': - return getModels() + return JSONResponse(getModels(), headers=NOCACHE_HEADERS) elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) - elif key == 'output_dir': return {outpath} + elif key == 'output_dir': return JSONResponse({outpath}, headers=NOCACHE_HEADERS) else: return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found From 7de699c7fa5e2782466e42a284666374d815b494 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 03:28:20 -0400 Subject: [PATCH 13/20] Moved a lot of code into task_manager.py --- ui/sd_internal/task_manager.py | 288 ++++++++++++++++++++++++++++ ui/server.py | 332 +++++---------------------------- 2 files changed, 334 insertions(+), 286 deletions(-) create mode 100644 ui/sd_internal/task_manager.py diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py new file mode 100644 index 00000000..67b0bde0 --- /dev/null +++ b/ui/sd_internal/task_manager.py @@ -0,0 +1,288 @@ +import json +import traceback + +TASK_TTL = 15 * 60 # Discard last session's task timeout + +import queue, threading, time +from typing import Any, Generator, Hashable, Optional, Union + +from pydantic import BaseModel +from sd_internal import Request, Response + +class SymbolClass(type): # Print nicely formatted Symbol names. + def __repr__(self): return self.__qualname__ + def __str__(self): return self.__name__ +class Symbol(metaclass=SymbolClass): pass + +class ServerStates: + class Init(Symbol): pass + class LoadingModel(Symbol): pass + class Online(Symbol): pass + class Rendering(Symbol): pass + class Unavailable(Symbol): pass + +class RenderTask(): # Task with output queue and completion lock. + def __init__(self, req: Request): + self.request: Request = req # Initial Request + self.response: Any = None # Copy of the last reponse + self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) + self.error: Exception = None + self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + async def read_buffer_generator(self): + try: + while not self.buffer_queue.empty(): + res = self.buffer_queue.get(block=False) + self.buffer_queue.task_done() + yield res + except queue.Empty as e: yield + +# defaults from https://huggingface.co/blog/stable_diffusion +class ImageRequest(BaseModel): + session_id: str = "session" + prompt: str = "" + negative_prompt: str = "" + init_image: str = None # base64 + mask: str = None # base64 + num_outputs: int = 1 + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + width: int = 512 + height: int = 512 + seed: int = 42 + prompt_strength: float = 0.8 + sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + # allow_nsfw: bool = False + save_to_disk_path: str = None + turbo: bool = True + use_cpu: bool = False + use_full_precision: bool = False + use_face_correction: str = None # or "GFPGANv1.3" + use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" + use_stable_diffusion_model: str = "sd-v1-4" + show_only_filtered_image: bool = False + output_format: str = "jpeg" # or "png" + + stream_progress_updates: bool = False + stream_image_progress: bool = False + +# Temporary cache to allow to query tasks results for a short time after they are completed. +class TaskCache(): + def __init__(self): + self._base = dict() + self._lock: threading.Lock = threading.Lock() + def _get_ttl_time(self, ttl: int) -> int: + return int(time.time()) + ttl + def _is_expired(self, timestamp: int) -> bool: + return int(time.time()) >= timestamp + def clean(self) -> None: + self._lock.acquire() + try: + for key in self._base: + ttl, _ = self._base[key] + if self._is_expired(ttl): + del self._base[key] + finally: + self._lock.release() + def clear(self) -> None: + self._lock.acquire() + try: self._base.clear() + finally: self._lock.release() + def delete(self, key: Hashable) -> bool: + self._lock.acquire() + try: + if key not in self._base: + return False + del self._base[key] + return True + finally: + self._lock.release() + def keep(self, key: Hashable, ttl: int) -> bool: + self._lock.acquire() + try: + if key in self._base: + _, value = self._base.get(key) + self._base[key] = (self._get_ttl_time(ttl), value) + return True + return False + finally: + self._lock.release() + def put(self, key: Hashable, value: Any, ttl: int) -> bool: + self._lock.acquire() + try: + self._base[key] = ( + self._get_ttl_time(ttl), value + ) + except Exception: + return False + else: + return True + finally: + self._lock.release() + def tryGet(self, key: Hashable) -> Any: + self._lock.acquire() + try: + ttl, value = self._base.get(key, (None, None)) + if ttl is not None and self._is_expired(ttl): + self.delete(key) + return None + return value + finally: + self._lock.release() + +current_state = ServerStates.Init +current_state_error:Exception = None +current_model_path = None +tasks_queue = queue.Queue() +task_cache = TaskCache() +default_model_to_load = None + +def preload_model(file_path=None): + global current_state, current_state_error, current_model_path + if file_path == None: + file_path = default_model_to_load + if file_path == current_model_path: + return + current_state = ServerStates.LoadingModel + try: + from . import runtime + runtime.load_model_ckpt(ckpt_to_use=file_path) + current_model_path = file_path + current_state_error = None + current_state = ServerStates.Online + except Exception as e: + current_model_path = None + current_state_error = e + current_state = ServerStates.Unavailable + print(traceback.format_exc()) + +def thread_render(): + global current_state, current_state_error, current_model_path + from . import runtime + current_state = ServerStates.Online + preload_model() + while True: + task_cache.clean() + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + task = None + try: + task = tasks_queue.get(timeout=1) + except queue.Empty as e: + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + else: continue + #if current_model_path != task.request.use_stable_diffusion_model: + # preload_model(task.request.use_stable_diffusion_model) + if current_state_error: + task.error = current_state_error + continue + print(f'Session {task.request.session_id} starting task {id(task)}') + try: + task.lock.acquire(blocking=False) + res = runtime.mk_img(task.request) + if current_model_path == task.request.use_stable_diffusion_model: + current_state = ServerStates.Rendering + else: + current_state = ServerStates.LoadingModel + except Exception as e: + task.error = e + task.lock.release() + tasks_queue.task_done() + print(traceback.format_exc()) + continue + dataQueue = None + if task.request.stream_progress_updates: + dataQueue = task.buffer_queue + for result in res: + if current_state == ServerStates.LoadingModel: + current_state = ServerStates.Rendering + current_model_path = task.request.use_stable_diffusion_model + if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): + runtime.stop_processing = True + if isinstance(current_state_error, StopAsyncIteration): + task.error = current_state_error + current_state_error = None + print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + if dataQueue: + dataQueue.put(result) + if isinstance(result, str): + result = json.loads(result) + task.response = result + if 'output' in result: + for out_obj in result['output']: + if 'path' in out_obj: + img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] + task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] + elif 'data' in out_obj: + task.temp_images[result['output'].index(out_obj)] = out_obj['data'] + task_cache.keep(task.request.session_id, TASK_TTL) + # Task completed + task.lock.release() + tasks_queue.task_done() + task_cache.keep(task.request.session_id, TASK_TTL) + if isinstance(task.error, StopAsyncIteration): + print(f'Session {task.request.session_id} task {id(task)} cancelled!') + elif task.error is not None: + print(f'Session {task.request.session_id} task {id(task)} failed!') + else: + print(f'Session {task.request.session_id} task {id(task)} completed.') + current_state = ServerStates.Online + +render_thread = threading.Thread(target=thread_render) + +def start_render_thread(): + # Start Rendering Thread + render_thread.daemon = True + render_thread.start() + +def shutdown_event(): # Signal render thread to close on shutdown + global current_state_error + current_state_error = SystemExit('Application shutting down.') + +def render(req : ImageRequest): + if not render_thread.is_alive(): # Render thread is dead + raise ChildProcessError('Rendering thread has died.') + # Alive, check if task in cache + task = task_cache.tryGet(req.session_id) + if task and not task.response and not task.error and not task.lock.locked(): + # Unstarted task pending, deny queueing more than one. + raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') + # + from . import runtime + r = Request() + r.session_id = req.session_id + r.prompt = req.prompt + r.negative_prompt = req.negative_prompt + r.init_image = req.init_image + r.mask = req.mask + r.num_outputs = req.num_outputs + r.num_inference_steps = req.num_inference_steps + r.guidance_scale = req.guidance_scale + r.width = req.width + r.height = req.height + r.seed = req.seed + r.prompt_strength = req.prompt_strength + r.sampler = req.sampler + # r.allow_nsfw = req.allow_nsfw + r.turbo = req.turbo + r.use_cpu = req.use_cpu + r.use_full_precision = req.use_full_precision + r.save_to_disk_path = req.save_to_disk_path + r.use_upscale: str = req.use_upscale + r.use_face_correction = req.use_face_correction + r.show_only_filtered_image = req.show_only_filtered_image + r.output_format = req.output_format + + r.stream_progress_updates = True # the underlying implementation only supports streaming + r.stream_image_progress = req.stream_image_progress + + if not req.stream_progress_updates: + r.stream_image_progress = False + + new_task = RenderTask(r) + task_cache.put(r.session_id, new_task, TASK_TTL) + tasks_queue.put(new_task) + return new_task diff --git a/ui/server.py b/ui/server.py index 7522f234..478b9404 100644 --- a/ui/server.py +++ b/ui/server.py @@ -24,7 +24,7 @@ import logging import queue, threading, time from typing import Any, Generator, Hashable, Optional, Union -from sd_internal import Request, Response +from sd_internal import Request, Response, task_manager app = FastAPI() @@ -37,214 +37,9 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") -class SymbolClass(type): # Print nicely formatted Symbol names. - def __repr__(self): return self.__qualname__ - def __str__(self): return self.__name__ -class Symbol(metaclass=SymbolClass): pass - -class ServerStates: - class Init(Symbol): pass - class LoadingModel(Symbol): pass - class Online(Symbol): pass - class Rendering(Symbol): pass - class Unavailable(Symbol): pass - -class RenderTask(): # Task with output queue and completion lock. - def __init__(self, req: Request): - self.request: Request = req # Initial Request - self.response: Any = None # Copy of the last reponse - self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) - self.error: Exception = None - self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed - self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments - -current_state = ServerStates.Init -current_state_error:Exception = None -current_model_path = None -tasks_queue = queue.Queue() - -# defaults from https://huggingface.co/blog/stable_diffusion -class ImageRequest(BaseModel): - session_id: str = "session" - prompt: str = "" - negative_prompt: str = "" - init_image: str = None # base64 - mask: str = None # base64 - num_outputs: int = 1 - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = 42 - prompt_strength: float = 0.8 - sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" - # allow_nsfw: bool = False - save_to_disk_path: str = None - turbo: bool = True - use_cpu: bool = False - use_full_precision: bool = False - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" - use_stable_diffusion_model: str = "sd-v1-4" - show_only_filtered_image: bool = False - output_format: str = "jpeg" # or "png" - - stream_progress_updates: bool = False - stream_image_progress: bool = False - -# Temporary cache to allow to query tasks results for a short time after they are completed. -class TaskCache(): - def __init__(self): - self._base = dict() - def _get_ttl_time(self, ttl: int) -> int: - return int(time.time()) + ttl - def _is_expired(self, timestamp: int) -> bool: - return int(time.time()) >= timestamp - def clean(self) -> None: - for key in self._base: - ttl, _ = self._base[key] - if self._is_expired(ttl): - del self._base[key] - def clear(self) -> None: - self._base.clear() - def delete(self, key: Hashable) -> bool: - if key not in self._base: - return False - del self._base[key] - return True - def keep(self, key: Hashable, ttl: int) -> bool: - if key in self._base: - _, value = self._base.get(key) - self._base[key] = (self._get_ttl_time(ttl), value) - return True - return False - def put(self, key: Hashable, value: Any, ttl: int) -> bool: - try: - self._base[key] = ( - self._get_ttl_time(ttl), value - ) - except Exception: - return False - return True - def tryGet(self, key: Hashable) -> Any: - ttl, value = self._base.get(key, (None, None)) - if ttl is not None and self._is_expired(ttl): - self.delete(key) - return None - return value - -task_cache = TaskCache() - class SetAppConfigRequest(BaseModel): update_branch: str = "main" -@app.get('/') -def read_root(): - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) - -def preload_model(file_path=None): - global current_state, current_state_error, current_model_path - if file_path == None: - file_path = get_initial_model_to_load() - if file_path == current_model_path: - return - current_state = ServerStates.LoadingModel - try: - from sd_internal import runtime - runtime.load_model_ckpt(ckpt_to_use=file_path) - current_model_path = file_path - current_state_error = None - current_state = ServerStates.Online - except Exception as e: - current_model_path = None - current_state_error = e - current_state = ServerStates.Unavailable - print(traceback.format_exc()) - -def thread_render(): - global current_state, current_state_error, current_model_path - from sd_internal import runtime - current_state = ServerStates.Online - preload_model() - while True: - task_cache.clean() - if isinstance(current_state_error, SystemExit): - current_state = ServerStates.Unavailable - return - task = None - try: - task = tasks_queue.get(timeout=1) - except queue.Empty as e: - if isinstance(current_state_error, SystemExit): - current_state = ServerStates.Unavailable - return - else: continue - #if current_model_path != task.request.use_stable_diffusion_model: - # preload_model(task.request.use_stable_diffusion_model) - if current_state_error: - task.error = current_state_error - continue - print(f'Session {task.request.session_id} starting task {id(task)}') - try: - task.lock.acquire(blocking=False) - res = runtime.mk_img(task.request) - if current_model_path == task.request.use_stable_diffusion_model: - current_state = ServerStates.Rendering - else: - current_state = ServerStates.LoadingModel - except Exception as e: - task.error = e - task.lock.release() - tasks_queue.task_done() - print(traceback.format_exc()) - continue - dataQueue = None - if task.request.stream_progress_updates: - dataQueue = task.buffer_queue - for result in res: - if current_state == ServerStates.LoadingModel: - current_state = ServerStates.Rendering - current_model_path = task.request.use_stable_diffusion_model - if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): - runtime.stop_processing = True - if isinstance(current_state_error, StopAsyncIteration): - task.error = current_state_error - current_state_error = None - print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') - if dataQueue: - dataQueue.put(result) - if isinstance(result, str): - result = json.loads(result) - task.response = result - if 'output' in result: - for out_obj in result['output']: - if 'path' in out_obj: - img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] - task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] - elif 'data' in out_obj: - task.temp_images[result['output'].index(out_obj)] = out_obj['data'] - task_cache.keep(task.request.session_id, TASK_TTL) - # Task completed - task.lock.release() - tasks_queue.task_done() - task_cache.keep(task.request.session_id, TASK_TTL) - if isinstance(task.error, StopAsyncIteration): - print(f'Session {task.request.session_id} task {id(task)} cancelled!') - elif task.error is not None: - print(f'Session {task.request.session_id} task {id(task)} failed!') - else: - print(f'Session {task.request.session_id} task {id(task)} completed.') - current_state = ServerStates.Online -# Start Rendering Thread -render_thread = threading.Thread(target=thread_render) -render_thread.daemon = True -render_thread.start() - -@app.on_event("shutdown") -def shutdown_event(): # Signal render thread to close on shutdown - global current_state_error - current_state_error = SystemExit('Application shutting down.') - # needs to support the legacy installations def get_initial_model_to_load(): custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') @@ -261,7 +56,6 @@ def get_initial_model_to_load(): ckpt_to_use = model_path else: print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt') - return ckpt_to_use def resolve_model_to_use(model_name): @@ -275,24 +69,24 @@ def resolve_model_to_use(model_name): model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) return model_path -def save_model_to_config(model_name): - config = getConfig() - if 'model' not in config: - config['model'] = {} +@app.on_event("shutdown") +def shutdown_event(): # Signal render thread to close on shutdown + task_manager.current_state_error = SystemExit('Application shutting down.') - config['model']['stable-diffusion'] = model_name - setConfig(config) +@app.get('/') +def read_root(): + return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) @app.get('/ping') # Get server and optionally session status. def ping(session_id:str=None): - if not render_thread.is_alive(): # Render thread is dead. - if current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) + if not task_manager.render_thread.is_alive(): # Render thread is dead. + if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) return HTTPException(status_code=500, detail='Render thread is dead.') - if current_state_error and not isinstance(current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) + if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) # Alive - response = {'status': str(current_state)} + response = {'status': str(task_manager.current_state)} if session_id: - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if task: response['task'] = id(task) if task.lock.locked(): @@ -309,74 +103,38 @@ def ping(session_id:str=None): response['session'] = 'pending' return JSONResponse(response, headers=NOCACHE_HEADERS) +def save_model_to_config(model_name): + config = getConfig() + if 'model' not in config: + config['model'] = {} + + config['model']['stable-diffusion'] = model_name + setConfig(config) + @app.post('/render') -def render(req : ImageRequest): - if not render_thread.is_alive(): # Render thread is dead - return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error - # Alive, check if task in cache - task = task_cache.tryGet(req.session_id) - if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one. - return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable - # - from sd_internal import runtime - r = Request() - r.session_id = req.session_id - r.prompt = req.prompt - r.negative_prompt = req.negative_prompt - r.init_image = req.init_image - r.mask = req.mask - r.num_outputs = req.num_outputs - r.num_inference_steps = req.num_inference_steps - r.guidance_scale = req.guidance_scale - r.width = req.width - r.height = req.height - r.seed = req.seed - r.prompt_strength = req.prompt_strength - r.sampler = req.sampler - # r.allow_nsfw = req.allow_nsfw - r.turbo = req.turbo - r.use_cpu = req.use_cpu - r.use_full_precision = req.use_full_precision - r.save_to_disk_path = req.save_to_disk_path - r.use_upscale: str = req.use_upscale - r.use_face_correction = req.use_face_correction - r.show_only_filtered_image = req.show_only_filtered_image - r.output_format = req.output_format - - r.stream_progress_updates = True # the underlying implementation only supports streaming - r.stream_image_progress = req.stream_image_progress - - r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model) - - save_model_to_config(req.use_stable_diffusion_model) - - if not req.stream_progress_updates: - r.stream_image_progress = False - - new_task = RenderTask(r) - task_cache.put(r.session_id, new_task, TASK_TTL) - tasks_queue.put(new_task) - - response = { - 'status': str(current_state), - 'queue': tasks_queue.qsize(), - 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', - 'task': id(new_task) - } - return JSONResponse(response, headers=NOCACHE_HEADERS) - -async def read_data_generator(data:queue.Queue, lock:threading.Lock): +def render(req : task_manager.ImageRequest): try: - while not data.empty(): - res = data.get(block=False) - data.task_done() - yield res - except queue.Empty as e: yield + save_model_to_config(req.use_stable_diffusion_model) + req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model) + new_task = task_manager.render(req) + response = { + 'status': str(task_manager.current_state), + 'queue': task_manager.tasks_queue.qsize(), + 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', + 'task': id(new_task) + } + return JSONResponse(response, headers=NOCACHE_HEADERS) + except ChildProcessError as e: # Render thread is dead + return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error + except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one. + return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable + except Exception as e: + return HTTPException(status_code=500, detail=str(e)) @app.get('/image/stream/{session_id:str}/{task_id:int}') def stream(session_id:str, task_id:int): #TODO Move to WebSockets ?? - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): @@ -385,17 +143,16 @@ def stream(session_id:str, task_id:int): return JSONResponse(task.response, headers=NOCACHE_HEADERS) return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') - return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json') + return StreamingResponse(task.read_buffer_generator(), media_type='application/json') @app.get('/image/stop') def stop(session_id:str=None): if not session_id: - if current_state == ServerStates.Online or current_state == ServerStates.Unavailable: + if task_manager.current_state == ServerStates.Online or task_manager.current_state == ServerStates.Unavailable: return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict - global current_state_error - current_state_error = StopAsyncIteration() + task_manager.current_state_error = StopAsyncIteration('') return {'OK'} - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict task.error = StopAsyncIteration('') @@ -403,7 +160,7 @@ def stop(session_id:str=None): @app.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early try: @@ -520,5 +277,8 @@ class LogSuppressFilter(logging.Filter): return True logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) +task_manager.default_model_to_load = get_initial_model_to_load() +task_manager.start_render_thread() + # start the browser ui import webbrowser; webbrowser.open('http://localhost:9000') \ No newline at end of file From 1b324238815824896b6a93178334c917fc57ba8e Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 03:32:00 -0400 Subject: [PATCH 14/20] Renamed a missing ServerStates to task_manager.ServerStates --- ui/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/server.py b/ui/server.py index 478b9404..b4ea05f4 100644 --- a/ui/server.py +++ b/ui/server.py @@ -148,7 +148,7 @@ def stream(session_id:str, task_id:int): @app.get('/image/stop') def stop(session_id:str=None): if not session_id: - if task_manager.current_state == ServerStates.Online or task_manager.current_state == ServerStates.Unavailable: + if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict task_manager.current_state_error = StopAsyncIteration('') return {'OK'} From d3b28c42e6285a750f699a643e01fc3db933d129 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 04:08:17 -0400 Subject: [PATCH 15/20] Better error handling with cache.put --- ui/sd_internal/task_manager.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 67b0bde0..d97379fc 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -113,7 +113,9 @@ class TaskCache(): self._base[key] = ( self._get_ttl_time(ttl), value ) - except Exception: + except Exception as e: + print(str(e)) + print(traceback.format_exc()) return False else: return True @@ -283,6 +285,7 @@ def render(req : ImageRequest): r.stream_image_progress = False new_task = RenderTask(r) - task_cache.put(r.session_id, new_task, TASK_TTL) - tasks_queue.put(new_task) - return new_task + if task_cache.put(r.session_id, new_task, TASK_TTL): + tasks_queue.put(new_task) + return new_task + raise RuntimeError('Failed to add task to cache.') From 8fdb1e7ec981b4f2d81fecdad879a6d8c6472edf Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 04:39:45 -0400 Subject: [PATCH 16/20] Improved locking and logging when cleaning old cached sessions. --- ui/sd_internal/task_manager.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index d97379fc..b3edfd96 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -70,7 +70,7 @@ class ImageRequest(BaseModel): class TaskCache(): def __init__(self): self._base = dict() - self._lock: threading.Lock = threading.Lock() + self._lock: threading.Lock = threading.RLock() def _get_ttl_time(self, ttl: int) -> int: return int(time.time()) + ttl def _is_expired(self, timestamp: int) -> bool: @@ -78,10 +78,16 @@ class TaskCache(): def clean(self) -> None: self._lock.acquire() try: + # Create a list of expired keys to delete + to_delete = [] for key in self._base: ttl, _ = self._base[key] if self._is_expired(ttl): - del self._base[key] + to_delete.append(key) + # Remove Items + for key in to_delete: + del self._base[key] + print(f'Session {key} expired. Data removed.') finally: self._lock.release() def clear(self) -> None: @@ -126,6 +132,7 @@ class TaskCache(): try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): + print(f'Session {key} expired. Discarding data.') self.delete(key) return None return value From 7625e591feebb0d641c17dadd7eab6ab6ec67d83 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 04:47:12 -0400 Subject: [PATCH 17/20] Fixed output_dir not liking the move to JSONResponse --- ui/media/main.js | 2 +- ui/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index c585be5a..952ec1e8 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -1497,7 +1497,7 @@ async function getDiskPath() { let res = await fetch('/get/output_dir') if (res.status === 200) { res = await res.json() - res = res[0] + res = res.output_dir document.querySelector('#diskPath').value = res } diff --git a/ui/server.py b/ui/server.py index b4ea05f4..a803ceb9 100644 --- a/ui/server.py +++ b/ui/server.py @@ -263,7 +263,7 @@ def read_web_data(key:str=None): elif key == 'models': return JSONResponse(getModels(), headers=NOCACHE_HEADERS) elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) - elif key == 'output_dir': return JSONResponse({outpath}, headers=NOCACHE_HEADERS) + elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS) else: return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found From cbdf03450d35ce1bb3ddecfaa561b25f3059e755 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 05:31:17 -0400 Subject: [PATCH 18/20] Added timeout to critical locking tasks with matching exception --- ui/sd_internal/task_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index b3edfd96..00364e0c 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -76,7 +76,7 @@ class TaskCache(): def _is_expired(self, timestamp: int) -> bool: return int(time.time()) >= timestamp def clean(self) -> None: - self._lock.acquire() + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.') try: # Create a list of expired keys to delete to_delete = [] @@ -91,11 +91,11 @@ class TaskCache(): finally: self._lock.release() def clear(self) -> None: - self._lock.acquire() + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.') try: self._base.clear() finally: self._lock.release() def delete(self, key: Hashable) -> bool: - self._lock.acquire() + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.') try: if key not in self._base: return False @@ -104,7 +104,7 @@ class TaskCache(): finally: self._lock.release() def keep(self, key: Hashable, ttl: int) -> bool: - self._lock.acquire() + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.') try: if key in self._base: _, value = self._base.get(key) @@ -114,7 +114,7 @@ class TaskCache(): finally: self._lock.release() def put(self, key: Hashable, value: Any, ttl: int) -> bool: - self._lock.acquire() + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.') try: self._base[key] = ( self._get_ttl_time(ttl), value @@ -128,7 +128,7 @@ class TaskCache(): finally: self._lock.release() def tryGet(self, key: Hashable) -> Any: - self._lock.acquire() + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.') try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): @@ -293,6 +293,6 @@ def render(req : ImageRequest): new_task = RenderTask(r) if task_cache.put(r.session_id, new_task, TASK_TTL): - tasks_queue.put(new_task) + tasks_queue.put(new_task, block=True, timeout=30) return new_task raise RuntimeError('Failed to add task to cache.') From 982b5221b19b192fbc92dda7ac652f8c25fcac46 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 05:48:12 -0400 Subject: [PATCH 19/20] Improved serverState tracking --- ui/media/main.js | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ui/media/main.js b/ui/media/main.js index 952ec1e8..b008bfc2 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -117,7 +117,7 @@ maskResetButton.innerHTML = 'Clear' maskResetButton.style.fontWeight = 'normal' maskResetButton.style.fontSize = '10pt' -let serverState = {'status': 'Offline'} +let serverState = {'status': 'Offline', 'time': Date.now()} let activeTags = [] let modifiers = [] let lastPromptUsed = '' @@ -288,6 +288,10 @@ async function healthCheck() { res = await fetch('/ping') } serverState = await res.json() + if (typeof serverState !== 'object' || typeof serverState.status !== 'string') { + serverState = {'status': 'Offline', 'time': Date.now()} + return + } // Set status switch(serverState.status) { case 'Init': @@ -308,6 +312,7 @@ async function healthCheck() { } serverState.time = Date.now() } catch (e) { + serverState = {'status': 'Offline', 'time': Date.now()} setServerStatus('error', 'offline') } } @@ -544,7 +549,7 @@ async function doMakeImage(task) { if (serverState.session !== 'pending' && serverState.session !== 'running' && serverState.session !== 'buffer') { throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined') } - while (serverState?.session === 'pending') { + while (serverState.task === renderRequest.task && serverState.session === 'pending') { // Wait for task to start on server. await asyncDelay(1500) } From 2edc06c662d2c6ee537266d8872f65c6207b1f49 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sun, 16 Oct 2022 21:32:59 -0400 Subject: [PATCH 20/20] Forgot to update UI if failed to get new server state --- ui/media/main.js | 1 + 1 file changed, 1 insertion(+) diff --git a/ui/media/main.js b/ui/media/main.js index b008bfc2..5276db1a 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -290,6 +290,7 @@ async function healthCheck() { serverState = await res.json() if (typeof serverState !== 'object' || typeof serverState.status !== 'string') { serverState = {'status': 'Offline', 'time': Date.now()} + setServerStatus('error', 'offline') return } // Set status