From 9ee0b7fe2e69fd23fe2a0ef6fb3dd6b00c9e3610 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 8 Dec 2022 10:04:14 +0530 Subject: [PATCH 1/4] SD 2.1 --- scripts/on_sd_start.bat | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 2daa39e2..2bcf56d2 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -44,7 +44,7 @@ if NOT DEFINED test_sd2 set test_sd2=N @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a ) if "%test_sd2%" == "Y" ( - @call git -c advice.detachedHead=false checkout b1a80dfc75388914252ce363f923103185eaf48f + @call git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f ) @cd .. From a8151176d779bbb74e64dfb6ea64ca43db918e20 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 8 Dec 2022 10:04:33 +0530 Subject: [PATCH 2/4] SD 2.1 --- scripts/on_sd_start.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index f8f3d560..8682c5cc 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -38,7 +38,7 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta if [ "$test_sd2" == "N" ]; then git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a elif [ "$test_sd2" == "Y" ]; then - git -c advice.detachedHead=false checkout b1a80dfc75388914252ce363f923103185eaf48f + git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f fi cd .. From f8dee7e25f8816d913b0e1e7da54f13c57e728d5 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Thu, 8 Dec 2022 00:27:50 -0500 Subject: [PATCH 3/4] Add test sample to one of the plugin. (#626) * Added test example from a plugin. * Only load style if #news was created. --- ui/plugins/ui/release-notes.plugin.js | 31 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/ui/plugins/ui/release-notes.plugin.js b/ui/plugins/ui/release-notes.plugin.js index 9cd07659..cfe2b338 100644 --- a/ui/plugins/ui/release-notes.plugin.js +++ b/ui/plugins/ui/release-notes.plugin.js @@ -1,4 +1,14 @@ (function() { + // Register selftests when loaded by jasmine. + if (typeof PLUGINS?.SELFTEST === 'object') { + PLUGINS.SELFTEST["release-notes"] = function() { + it('should be able to fetch CHANGES.md', async function() { + let releaseNotes = await fetch(`https://raw.githubusercontent.com/cmdr2/stable-diffusion-ui/main/CHANGES.md`) + expect(releaseNotes.status).toBe(200) + }) + } + } + document.querySelector('#tab-container')?.insertAdjacentHTML('beforeend', ` What's new? @@ -13,7 +23,17 @@ `) - document.querySelector('body')?.insertAdjacentHTML('beforeend', ` + const tabNews = document.querySelector('#tab-news') + if (tabNews) { + linkTabContents(tabNews) + } + const news = document.querySelector('#news') + if (!news) { + // news tab not found, dont exec plugin code. + return + } + + document.querySelector('body').insertAdjacentHTML('beforeend', ` `) - const tabNews = document.querySelector('#tab-news') - if (tabNews) { - linkTabContents(tabNews) - } - const news = document.querySelector('#news') - if (!news) { - return - } - const markedScript = document.createElement('script') markedScript.src = '/media/js/marked.min.js' From ba2c9663292fb2e3d85c6aa7f2442982ebf1e52b Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Thu, 8 Dec 2022 00:42:46 -0500 Subject: [PATCH 4/4] First draft of multi-task in a single session. (#622) --- ui/media/js/engine.js | 78 +++++++++++++++------- ui/media/js/main.js | 5 +- ui/plugins/ui/jasmineSpec.js | 4 +- ui/sd_internal/__init__.py | 1 + ui/sd_internal/runtime.py | 12 ++-- ui/sd_internal/task_manager.py | 116 +++++++++++++++++++++++++-------- ui/server.py | 58 +++++++---------- 7 files changed, 179 insertions(+), 95 deletions(-) diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js index f357b8bc..dd34ddb1 100644 --- a/ui/media/js/engine.js +++ b/ui/media/js/engine.js @@ -383,7 +383,7 @@ throw new Error('exception is not an Error or a string.') } } - const res = await fetch('/image/stop?session_id=' + SD.sessionId) + const res = await fetch('/image/stop?task=' + this.id) if (!res.ok) { console.log('Stop response:', res) throw new Error(res.statusText) @@ -556,13 +556,19 @@ case TaskStatus.pending: case TaskStatus.waiting: // Wait for server status to include this task. - await waitUntil(async () => ((task.#id && serverState.task === task.#id) - || await Promise.resolve(callback?.call(task)) - || signal?.aborted), + await waitUntil( + async () => { + if (task.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) { + return true + } + if (await Promise.resolve(callback?.call(task)) || signal?.aborted) { + return true + } + }, TASK_STATE_SERVER_UPDATE_DELAY, timeout, ) - if (this.#id && serverState.task === this.#id) { + if (this.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) { this._setStatus(TaskStatus.waiting) } if (await Promise.resolve(callback?.call(this)) || signal?.aborted) { @@ -572,19 +578,20 @@ return true } // Wait for task to start on server. - await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'pending' - || await Promise.resolve(callback?.call(task)) - || signal?.aborted), + await waitUntil( + async () => { + if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'pending') { + return true + } + if (await Promise.resolve(callback?.call(task)) || signal?.aborted) { + return true + } + }, TASK_STATE_SERVER_UPDATE_DELAY, timeout, ) - if (serverState.task === this.#id - && ( - serverState.session === 'running' - || serverState.session === 'buffer' - || serverState.session === 'completed' - ) - ) { + const state = (typeof serverState.tasks === 'object' ? serverState.tasks[String(task.#id)] : undefined) + if (state === 'running' || state === 'buffer' || state === 'completed') { this._setStatus(TaskStatus.processing) } if (await Promise.resolve(callback?.call(task)) || signal?.aborted) { @@ -594,9 +601,15 @@ return true } case TaskStatus.processing: - await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'running' - || await Promise.resolve(callback?.call(task)) - || signal?.aborted), + await waitUntil( + async () => { + if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'running') { + return true + } + if (await Promise.resolve(callback?.call(task)) || signal?.aborted) { + return true + } + }, TASK_STATE_SERVER_UPDATE_DELAY, timeout, ) @@ -882,7 +895,8 @@ throw e } // Update class status and callback. - switch(serverState.session) { + const taskState = (typeof serverState.tasks === 'object' ? serverState.tasks[String(this.id)] : undefined) + switch(taskState) { case 'pending': // Session has pending tasks. console.error('Server %o render request %o is still waiting.', serverState, renderRequest) //Only update status if not already set by waitUntil @@ -915,7 +929,7 @@ return false default: if (!progressCallback) { - const err = new Error('Unexpected server task state: ' + serverState.session || 'Undefined') + const err = new Error('Unexpected server task state: ' + taskState || 'Undefined') this.abort(err) throw err } @@ -1065,6 +1079,14 @@ return models } + function getServerCapacity() { + let activeDevicesCount = Object.keys(serverState?.devices?.active || {}).length + if (window.document.visibilityState === 'hidden') { + activeDevicesCount = 1 + activeDevicesCount + } + return activeDevicesCount + } + function continueTasks() { if (typeof navigator?.scheduling?.isInputPending === 'function') { const inputPendingOptions = { @@ -1077,13 +1099,17 @@ return asyncDelay(CONCURRENT_TASK_INTERVAL) } } + const serverCapacity = getServerCapacity() if (task_queue.size <= 0 && concurrent_generators.size <= 0) { - eventSource.fireEvent(EVENT_IDLE, {}) + eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity, idle: true}) // Calling idle could result in task being added to queue. if (task_queue.size <= 0 && concurrent_generators.size <= 0) { return asyncDelay(IDLE_COOLDOWN) } } + if (task_queue.size < serverCapacity) { + eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity - task_queue.size}) + } const completedTasks = [] for (let [generator, promise] of concurrent_generators.entries()) { if (promise.isPending) { @@ -1122,7 +1148,6 @@ concurrent_generators.set(generator, promise) } - const serverCapacity = 2 for (let [task, generator] of task_queue.entries()) { const cTsk = completedTasks.find((item) => item.generator === generator) if (cTsk?.promise?.rejectReason || task.hasFailed) { @@ -1211,6 +1236,7 @@ removeEventListener: (...args) => eventSource.removeEventListener(...args), isServerAvailable, + getServerCapacity, getSystemInfo, getDevices, @@ -1228,6 +1254,14 @@ configurable: false, get: () => serverState, }, + isAvailable: { + configurable: false, + get: () => isServerAvailable(), + }, + serverCapacity: { + configurable: false, + get: () => getServerCapacity(), + }, sessionId: { configurable: false, get: () => sessionId, diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 7192b0d3..67ff7a9a 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -455,9 +455,10 @@ function makeImage() { } function onIdle() { + const serverCapacity = SD.serverCapacity for (const taskEntry of getUncompletedTaskEntries()) { - if (SD.activeTasks.size >= 1) { - continue + if (SD.activeTasks.size >= serverCapacity) { + break } const task = htmlTaskMap.get(taskEntry) if (!task) { diff --git a/ui/plugins/ui/jasmineSpec.js b/ui/plugins/ui/jasmineSpec.js index 27329380..b97bbd4c 100644 --- a/ui/plugins/ui/jasmineSpec.js +++ b/ui/plugins/ui/jasmineSpec.js @@ -202,12 +202,12 @@ describe('stable-diffusion-ui', function() { // Wait for server status to update. await SD.waitUntil(() => { console.log('Waiting for %s to be received...', renderRequest.task) - return (!SD.serverState.task || SD.serverState.task === renderRequest.task) + return (!SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)]) }, 250, 10 * 60 * 1000) // Wait for task to start on server. await SD.waitUntil(() => { console.log('Waiting for %s to start...', renderRequest.task) - return SD.serverState.task !== renderRequest.task || SD.serverState.session !== 'pending' + return !SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)] !== 'pending' }, 250) const reader = new SD.ChunkedStreamReader(renderRequest.stream) diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index a2abe294..0a1590f0 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -1,6 +1,7 @@ import json class Request: + request_id: str = None session_id: str = "session" prompt: str = "" negative_prompt: str = "" diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 30d19ef1..d3fd6e86 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -523,15 +523,16 @@ def update_temp_img(req, x_samples, task_temp_images: list): del img, x_sample, x_sample_ddim # don't delete x_samples, it is used in the code that called this callback - thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf + thread_data.temp_images[f'{req.request_id}/{i}'] = buf task_temp_images[i] = buf - partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) + partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'}) return partial_images # Build and return the apropriate generator for do_mk_img def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None): if not req.stream_progress_updates: - def empty_callback(x_samples, i): return x_samples + def empty_callback(x_samples, i): + step_callback() return empty_callback thread_data.partial_x_samples = None @@ -639,11 +640,6 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste t_enc = int(req.prompt_strength * req.num_inference_steps) print(f"target t_enc is {t_enc} steps") - if req.save_to_disk_path is not None: - session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id) - else: - session_out_path = None - with torch.no_grad(): for n in trange(opt_n_iter, desc="Sampling"): for prompts in tqdm(data, desc="data"): diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 41fc00f6..434394cf 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -37,7 +37,8 @@ class ServerStates: class RenderTask(): # Task with output queue and completion lock. def __init__(self, req: Request): - self.request: Request = req # Initial Request + req.request_id = id(self) + self.request: Request = req # Initial Request self.response: Any = None # Copy of the last reponse self.render_device = None # Select the task affinity. (Not used to change active devices). self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) @@ -51,6 +52,22 @@ class RenderTask(): # Task with output queue and completion lock. self.buffer_queue.task_done() yield res except queue.Empty as e: yield + @property + def status(self): + if self.lock.locked(): + return 'running' + if isinstance(self.error, StopAsyncIteration): + return 'stopped' + if self.error: + return 'error' + if not self.buffer_queue.empty(): + return 'buffer' + if self.response: + return 'completed' + return 'pending' + @property + def is_pending(self): + return bool(not self.response and not self.error) # defaults from https://huggingface.co/blog/stable_diffusion class ImageRequest(BaseModel): @@ -101,7 +118,7 @@ class FilterRequest(BaseModel): output_quality: int = 75 # Temporary cache to allow to query tasks results for a short time after they are completed. -class TaskCache(): +class DataCache(): def __init__(self): self._base = dict() self._lock: threading.Lock = threading.Lock() @@ -110,7 +127,7 @@ class TaskCache(): def _is_expired(self, timestamp: int) -> bool: return int(time.time()) >= timestamp def clean(self) -> None: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED) try: # Create a list of expired keys to delete to_delete = [] @@ -120,16 +137,22 @@ class TaskCache(): to_delete.append(key) # Remove Items for key in to_delete: + (_, val) = self._base[key] + if isinstance(val, RenderTask): + print(f'RenderTask {key} expired. Data removed.') + elif isinstance(val, SessionState): + print(f'Session {key} expired. Data removed.') + else: + print(f'Key {key} expired. Data removed.') del self._base[key] - print(f'Session {key} expired. Data removed.') finally: self._lock.release() def clear(self) -> None: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED) try: self._base.clear() finally: self._lock.release() def delete(self, key: Hashable) -> bool: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED) try: if key not in self._base: return False @@ -138,7 +161,7 @@ class TaskCache(): finally: self._lock.release() def keep(self, key: Hashable, ttl: int) -> bool: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED) try: if key in self._base: _, value = self._base.get(key) @@ -148,7 +171,7 @@ class TaskCache(): finally: self._lock.release() def put(self, key: Hashable, value: Any, ttl: int) -> bool: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED) try: self._base[key] = ( self._get_ttl_time(ttl), value @@ -162,7 +185,7 @@ class TaskCache(): finally: self._lock.release() def tryGet(self, key: Hashable) -> Any: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED) try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): @@ -181,11 +204,37 @@ current_model_path = None current_vae_path = None current_hypernetwork_path = None tasks_queue = [] -task_cache = TaskCache() +session_cache = DataCache() +task_cache = DataCache() default_model_to_load = None default_vae_to_load = None default_hypernetwork_to_load = None weak_thread_data = weakref.WeakKeyDictionary() +idle_event: threading.Event = threading.Event() + +class SessionState(): + def __init__(self, id: str): + self._id = id + self._tasks_ids = [] + @property + def id(self): + return self._id + @property + def tasks(self): + tasks = [] + for task_id in self._tasks_ids: + task = task_cache.tryGet(task_id) + if task: + tasks.append(task) + return tasks + def put(self, task, ttl=TASK_TTL): + task_id = id(task) + self._tasks_ids.append(task_id) + if not task_cache.put(task_id, task, ttl): + return False + while len(self._tasks_ids) > len(render_threads) * 2: + self._tasks_ids.pop(0) + return True def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None): global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path @@ -268,6 +317,7 @@ def thread_render(device): preload_model() current_state = ServerStates.Online while True: + session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: print(f'Shutting down thread for device {runtime.thread_data.device}') @@ -279,7 +329,8 @@ def thread_render(device): return task = thread_get_next_task() if task is None: - time.sleep(0.05) + idle_event.clear() + idle_event.wait(timeout=1) continue if task.error is not None: print(task.error) @@ -314,10 +365,11 @@ def thread_render(device): current_state_error = None print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') - task_cache.keep(task.request.session_id, TASK_TTL) - current_state = ServerStates.Rendering task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) + # Before looping back to the generator, mark cache as still alive. + task_cache.keep(id(task), TASK_TTL) + session_cache.keep(task.request.session_id, TASK_TTL) except Exception as e: task.error = e print(traceback.format_exc()) @@ -325,7 +377,8 @@ def thread_render(device): finally: # Task completed task.lock.release() - task_cache.keep(task.request.session_id, TASK_TTL) + task_cache.keep(id(task), TASK_TTL) + session_cache.keep(task.request.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): print(f'Session {task.request.session_id} task {id(task)} cancelled!') elif task.error is not None: @@ -334,12 +387,21 @@ def thread_render(device): print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.') current_state = ServerStates.Online -def get_cached_task(session_id:str, update_ttl:bool=False): +def get_cached_task(task_id:str, update_ttl:bool=False): # By calling keep before tryGet, wont discard if was expired. - if update_ttl and not task_cache.keep(session_id, TASK_TTL): + if update_ttl and not task_cache.keep(task_id, TASK_TTL): # Failed to keep task, already gone. return None - return task_cache.tryGet(session_id) + return task_cache.tryGet(task_id) + +def get_cached_session(session_id:str, update_ttl:bool=False): + if update_ttl: + session_cache.keep(session_id, TASK_TTL) + session = session_cache.tryGet(session_id) + if not session: + session = SessionState(session_id) + session_cache.put(session_id, session, TASK_TTL) + return session def get_devices(): devices = { @@ -486,14 +548,16 @@ def shutdown_event(): # Signal render thread to close on shutdown current_state_error = SystemExit('Application shutting down.') def render(req : ImageRequest): - if is_alive() <= 0: # Render thread is dead + current_thread_count = is_alive() + if current_thread_count <= 0: # Render thread is dead raise ChildProcessError('Rendering thread has died.') + # Alive, check if task in cache - task = task_cache.tryGet(req.session_id) - if task and not task.response and not task.error and not task.lock.locked(): - # Unstarted task pending, deny queueing more than one. - raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') - # + session = get_cached_session(req.session_id, update_ttl=True) + pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) + if current_thread_count < len(pending_tasks): + raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') + from . import runtime r = Request() r.session_id = req.session_id @@ -530,13 +594,13 @@ def render(req : ImageRequest): r.stream_image_progress = False new_task = RenderTask(r) - - if task_cache.put(r.session_id, new_task, TASK_TTL): + if session.put(new_task, TASK_TTL): # Use twice the normal timeout for adding user requests. - # Tries to force task_cache.put to fail before tasks_queue.put would. + # Tries to force session.put to fail before tasks_queue.put would. if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2): try: tasks_queue.append(new_task) + idle_event.set() return new_task finally: manager_lock.release() diff --git a/ui/server.py b/ui/server.py index 804994a2..2db312de 100644 --- a/ui/server.py +++ b/ui/server.py @@ -49,14 +49,12 @@ from fastapi.staticfiles import StaticFiles from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel import logging -#import queue, threading, time from typing import Any, Generator, Hashable, List, Optional, Union from sd_internal import Request, Response, task_manager app = FastAPI() -modifiers_cache = None outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) @@ -354,21 +352,8 @@ def ping(session_id:str=None): # Alive response = {'status': str(task_manager.current_state)} if session_id: - task = task_manager.get_cached_task(session_id, update_ttl=True) - if task: - response['task'] = id(task) - if task.lock.locked(): - response['session'] = 'running' - elif isinstance(task.error, StopAsyncIteration): - response['session'] = 'stopped' - elif task.error: - response['session'] = 'error' - elif not task.buffer_queue.empty(): - response['session'] = 'buffer' - elif task.response: - response['session'] = 'completed' - else: - response['session'] = 'pending' + session = task_manager.get_cached_session(session_id, update_ttl=True) + response['tasks'] = {id(t): t.status for t in session.tasks} response['devices'] = task_manager.get_devices() return JSONResponse(response, headers=NOCACHE_HEADERS) @@ -408,23 +393,25 @@ def render(req : task_manager.ImageRequest): response = { 'status': str(task_manager.current_state), 'queue': len(task_manager.tasks_queue), - 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', + 'stream': f'/image/stream/{id(new_task)}', 'task': id(new_task) } return JSONResponse(response, headers=NOCACHE_HEADERS) except ChildProcessError as e: # Render thread is dead raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error - except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one. - raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable + except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. + raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable except Exception as e: + print(e) + print(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) -@app.get('/image/stream/{session_id:str}/{task_id:int}') -def stream(session_id:str, task_id:int): +@app.get('/image/stream/{task_id:int}') +def stream(task_id:int): #TODO Move to WebSockets ?? - task = task_manager.get_cached_task(session_id, update_ttl=True) - if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone - if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict + task = task_manager.get_cached_task(task_id, update_ttl=True) + if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound + #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): if task.response: #print(f'Session {session_id} sending cached response') @@ -434,22 +421,23 @@ def stream(session_id:str, task_id:int): return StreamingResponse(task.read_buffer_generator(), media_type='application/json') @app.get('/image/stop') -def stop(session_id:str=None): - if not session_id: +def stop(task: int): + if not task: if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict task_manager.current_state_error = StopAsyncIteration('') return {'OK'} - task = task_manager.get_cached_task(session_id, update_ttl=False) - if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found - if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict - task.error = StopAsyncIteration('') + task_id = task + task = task_manager.get_cached_task(task_id, update_ttl=False) + if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found + if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict + task.error = StopAsyncIteration(f'Task {task_id} stop requested.') return {'OK'} -@app.get('/image/tmp/{session_id}/{img_id:int}') -def get_image(session_id, img_id): - task = task_manager.get_cached_task(session_id, update_ttl=True) - if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone +@app.get('/image/tmp/{task_id:int}/{img_id:int}') +def get_image(task_id: int, img_id: int): + task = task_manager.get_cached_task(task_id, update_ttl=True) + if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early try: img_data = task.temp_images[img_id]