mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-23 05:42:01 +01:00
First draft of multi-task in a single session. (#622)
This commit is contained in:
parent
f8dee7e25f
commit
ba2c966329
@ -383,7 +383,7 @@
|
||||
throw new Error('exception is not an Error or a string.')
|
||||
}
|
||||
}
|
||||
const res = await fetch('/image/stop?session_id=' + SD.sessionId)
|
||||
const res = await fetch('/image/stop?task=' + this.id)
|
||||
if (!res.ok) {
|
||||
console.log('Stop response:', res)
|
||||
throw new Error(res.statusText)
|
||||
@ -556,13 +556,19 @@
|
||||
case TaskStatus.pending:
|
||||
case TaskStatus.waiting:
|
||||
// Wait for server status to include this task.
|
||||
await waitUntil(async () => ((task.#id && serverState.task === task.#id)
|
||||
|| await Promise.resolve(callback?.call(task))
|
||||
|| signal?.aborted),
|
||||
await waitUntil(
|
||||
async () => {
|
||||
if (task.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) {
|
||||
return true
|
||||
}
|
||||
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||
return true
|
||||
}
|
||||
},
|
||||
TASK_STATE_SERVER_UPDATE_DELAY,
|
||||
timeout,
|
||||
)
|
||||
if (this.#id && serverState.task === this.#id) {
|
||||
if (this.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) {
|
||||
this._setStatus(TaskStatus.waiting)
|
||||
}
|
||||
if (await Promise.resolve(callback?.call(this)) || signal?.aborted) {
|
||||
@ -572,19 +578,20 @@
|
||||
return true
|
||||
}
|
||||
// Wait for task to start on server.
|
||||
await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'pending'
|
||||
|| await Promise.resolve(callback?.call(task))
|
||||
|| signal?.aborted),
|
||||
await waitUntil(
|
||||
async () => {
|
||||
if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'pending') {
|
||||
return true
|
||||
}
|
||||
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||
return true
|
||||
}
|
||||
},
|
||||
TASK_STATE_SERVER_UPDATE_DELAY,
|
||||
timeout,
|
||||
)
|
||||
if (serverState.task === this.#id
|
||||
&& (
|
||||
serverState.session === 'running'
|
||||
|| serverState.session === 'buffer'
|
||||
|| serverState.session === 'completed'
|
||||
)
|
||||
) {
|
||||
const state = (typeof serverState.tasks === 'object' ? serverState.tasks[String(task.#id)] : undefined)
|
||||
if (state === 'running' || state === 'buffer' || state === 'completed') {
|
||||
this._setStatus(TaskStatus.processing)
|
||||
}
|
||||
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||
@ -594,9 +601,15 @@
|
||||
return true
|
||||
}
|
||||
case TaskStatus.processing:
|
||||
await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'running'
|
||||
|| await Promise.resolve(callback?.call(task))
|
||||
|| signal?.aborted),
|
||||
await waitUntil(
|
||||
async () => {
|
||||
if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'running') {
|
||||
return true
|
||||
}
|
||||
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||
return true
|
||||
}
|
||||
},
|
||||
TASK_STATE_SERVER_UPDATE_DELAY,
|
||||
timeout,
|
||||
)
|
||||
@ -882,7 +895,8 @@
|
||||
throw e
|
||||
}
|
||||
// Update class status and callback.
|
||||
switch(serverState.session) {
|
||||
const taskState = (typeof serverState.tasks === 'object' ? serverState.tasks[String(this.id)] : undefined)
|
||||
switch(taskState) {
|
||||
case 'pending': // Session has pending tasks.
|
||||
console.error('Server %o render request %o is still waiting.', serverState, renderRequest)
|
||||
//Only update status if not already set by waitUntil
|
||||
@ -915,7 +929,7 @@
|
||||
return false
|
||||
default:
|
||||
if (!progressCallback) {
|
||||
const err = new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
|
||||
const err = new Error('Unexpected server task state: ' + taskState || 'Undefined')
|
||||
this.abort(err)
|
||||
throw err
|
||||
}
|
||||
@ -1065,6 +1079,14 @@
|
||||
return models
|
||||
}
|
||||
|
||||
function getServerCapacity() {
|
||||
let activeDevicesCount = Object.keys(serverState?.devices?.active || {}).length
|
||||
if (window.document.visibilityState === 'hidden') {
|
||||
activeDevicesCount = 1 + activeDevicesCount
|
||||
}
|
||||
return activeDevicesCount
|
||||
}
|
||||
|
||||
function continueTasks() {
|
||||
if (typeof navigator?.scheduling?.isInputPending === 'function') {
|
||||
const inputPendingOptions = {
|
||||
@ -1077,13 +1099,17 @@
|
||||
return asyncDelay(CONCURRENT_TASK_INTERVAL)
|
||||
}
|
||||
}
|
||||
const serverCapacity = getServerCapacity()
|
||||
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
|
||||
eventSource.fireEvent(EVENT_IDLE, {})
|
||||
eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity, idle: true})
|
||||
// Calling idle could result in task being added to queue.
|
||||
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
|
||||
return asyncDelay(IDLE_COOLDOWN)
|
||||
}
|
||||
}
|
||||
if (task_queue.size < serverCapacity) {
|
||||
eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity - task_queue.size})
|
||||
}
|
||||
const completedTasks = []
|
||||
for (let [generator, promise] of concurrent_generators.entries()) {
|
||||
if (promise.isPending) {
|
||||
@ -1122,7 +1148,6 @@
|
||||
concurrent_generators.set(generator, promise)
|
||||
}
|
||||
|
||||
const serverCapacity = 2
|
||||
for (let [task, generator] of task_queue.entries()) {
|
||||
const cTsk = completedTasks.find((item) => item.generator === generator)
|
||||
if (cTsk?.promise?.rejectReason || task.hasFailed) {
|
||||
@ -1211,6 +1236,7 @@
|
||||
removeEventListener: (...args) => eventSource.removeEventListener(...args),
|
||||
|
||||
isServerAvailable,
|
||||
getServerCapacity,
|
||||
|
||||
getSystemInfo,
|
||||
getDevices,
|
||||
@ -1228,6 +1254,14 @@
|
||||
configurable: false,
|
||||
get: () => serverState,
|
||||
},
|
||||
isAvailable: {
|
||||
configurable: false,
|
||||
get: () => isServerAvailable(),
|
||||
},
|
||||
serverCapacity: {
|
||||
configurable: false,
|
||||
get: () => getServerCapacity(),
|
||||
},
|
||||
sessionId: {
|
||||
configurable: false,
|
||||
get: () => sessionId,
|
||||
|
@ -455,9 +455,10 @@ function makeImage() {
|
||||
}
|
||||
|
||||
function onIdle() {
|
||||
const serverCapacity = SD.serverCapacity
|
||||
for (const taskEntry of getUncompletedTaskEntries()) {
|
||||
if (SD.activeTasks.size >= 1) {
|
||||
continue
|
||||
if (SD.activeTasks.size >= serverCapacity) {
|
||||
break
|
||||
}
|
||||
const task = htmlTaskMap.get(taskEntry)
|
||||
if (!task) {
|
||||
|
@ -202,12 +202,12 @@ describe('stable-diffusion-ui', function() {
|
||||
// Wait for server status to update.
|
||||
await SD.waitUntil(() => {
|
||||
console.log('Waiting for %s to be received...', renderRequest.task)
|
||||
return (!SD.serverState.task || SD.serverState.task === renderRequest.task)
|
||||
return (!SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)])
|
||||
}, 250, 10 * 60 * 1000)
|
||||
// Wait for task to start on server.
|
||||
await SD.waitUntil(() => {
|
||||
console.log('Waiting for %s to start...', renderRequest.task)
|
||||
return SD.serverState.task !== renderRequest.task || SD.serverState.session !== 'pending'
|
||||
return !SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)] !== 'pending'
|
||||
}, 250)
|
||||
|
||||
const reader = new SD.ChunkedStreamReader(renderRequest.stream)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
class Request:
|
||||
request_id: str = None
|
||||
session_id: str = "session"
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
|
@ -523,15 +523,16 @@ def update_temp_img(req, x_samples, task_temp_images: list):
|
||||
del img, x_sample, x_sample_ddim
|
||||
# don't delete x_samples, it is used in the code that called this callback
|
||||
|
||||
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
|
||||
return partial_images
|
||||
|
||||
# Build and return the apropriate generator for do_mk_img
|
||||
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
|
||||
if not req.stream_progress_updates:
|
||||
def empty_callback(x_samples, i): return x_samples
|
||||
def empty_callback(x_samples, i):
|
||||
step_callback()
|
||||
return empty_callback
|
||||
|
||||
thread_data.partial_x_samples = None
|
||||
@ -639,11 +640,6 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
|
||||
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
if req.save_to_disk_path is not None:
|
||||
session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
|
||||
else:
|
||||
session_out_path = None
|
||||
|
||||
with torch.no_grad():
|
||||
for n in trange(opt_n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
@ -37,7 +37,8 @@ class ServerStates:
|
||||
|
||||
class RenderTask(): # Task with output queue and completion lock.
|
||||
def __init__(self, req: Request):
|
||||
self.request: Request = req # Initial Request
|
||||
req.request_id = id(self)
|
||||
self.request: Request = req # Initial Request
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
||||
@ -51,6 +52,22 @@ class RenderTask(): # Task with output queue and completion lock.
|
||||
self.buffer_queue.task_done()
|
||||
yield res
|
||||
except queue.Empty as e: yield
|
||||
@property
|
||||
def status(self):
|
||||
if self.lock.locked():
|
||||
return 'running'
|
||||
if isinstance(self.error, StopAsyncIteration):
|
||||
return 'stopped'
|
||||
if self.error:
|
||||
return 'error'
|
||||
if not self.buffer_queue.empty():
|
||||
return 'buffer'
|
||||
if self.response:
|
||||
return 'completed'
|
||||
return 'pending'
|
||||
@property
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
# defaults from https://huggingface.co/blog/stable_diffusion
|
||||
class ImageRequest(BaseModel):
|
||||
@ -101,7 +118,7 @@ class FilterRequest(BaseModel):
|
||||
output_quality: int = 75
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class TaskCache():
|
||||
class DataCache():
|
||||
def __init__(self):
|
||||
self._base = dict()
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
@ -110,7 +127,7 @@ class TaskCache():
|
||||
def _is_expired(self, timestamp: int) -> bool:
|
||||
return int(time.time()) >= timestamp
|
||||
def clean(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
# Create a list of expired keys to delete
|
||||
to_delete = []
|
||||
@ -120,16 +137,22 @@ class TaskCache():
|
||||
to_delete.append(key)
|
||||
# Remove Items
|
||||
for key in to_delete:
|
||||
(_, val) = self._base[key]
|
||||
if isinstance(val, RenderTask):
|
||||
print(f'RenderTask {key} expired. Data removed.')
|
||||
elif isinstance(val, SessionState):
|
||||
print(f'Session {key} expired. Data removed.')
|
||||
else:
|
||||
print(f'Key {key} expired. Data removed.')
|
||||
del self._base[key]
|
||||
print(f'Session {key} expired. Data removed.')
|
||||
finally:
|
||||
self._lock.release()
|
||||
def clear(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED)
|
||||
try: self._base.clear()
|
||||
finally: self._lock.release()
|
||||
def delete(self, key: Hashable) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key not in self._base:
|
||||
return False
|
||||
@ -138,7 +161,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key in self._base:
|
||||
_, value = self._base.get(key)
|
||||
@ -148,7 +171,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
self._base[key] = (
|
||||
self._get_ttl_time(ttl), value
|
||||
@ -162,7 +185,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def tryGet(self, key: Hashable) -> Any:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
ttl, value = self._base.get(key, (None, None))
|
||||
if ttl is not None and self._is_expired(ttl):
|
||||
@ -181,11 +204,37 @@ current_model_path = None
|
||||
current_vae_path = None
|
||||
current_hypernetwork_path = None
|
||||
tasks_queue = []
|
||||
task_cache = TaskCache()
|
||||
session_cache = DataCache()
|
||||
task_cache = DataCache()
|
||||
default_model_to_load = None
|
||||
default_vae_to_load = None
|
||||
default_hypernetwork_to_load = None
|
||||
weak_thread_data = weakref.WeakKeyDictionary()
|
||||
idle_event: threading.Event = threading.Event()
|
||||
|
||||
class SessionState():
|
||||
def __init__(self, id: str):
|
||||
self._id = id
|
||||
self._tasks_ids = []
|
||||
@property
|
||||
def id(self):
|
||||
return self._id
|
||||
@property
|
||||
def tasks(self):
|
||||
tasks = []
|
||||
for task_id in self._tasks_ids:
|
||||
task = task_cache.tryGet(task_id)
|
||||
if task:
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
def put(self, task, ttl=TASK_TTL):
|
||||
task_id = id(task)
|
||||
self._tasks_ids.append(task_id)
|
||||
if not task_cache.put(task_id, task, ttl):
|
||||
return False
|
||||
while len(self._tasks_ids) > len(render_threads) * 2:
|
||||
self._tasks_ids.pop(0)
|
||||
return True
|
||||
|
||||
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
|
||||
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
||||
@ -268,6 +317,7 @@ def thread_render(device):
|
||||
preload_model()
|
||||
current_state = ServerStates.Online
|
||||
while True:
|
||||
session_cache.clean()
|
||||
task_cache.clean()
|
||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
||||
print(f'Shutting down thread for device {runtime.thread_data.device}')
|
||||
@ -279,7 +329,8 @@ def thread_render(device):
|
||||
return
|
||||
task = thread_get_next_task()
|
||||
if task is None:
|
||||
time.sleep(0.05)
|
||||
idle_event.clear()
|
||||
idle_event.wait(timeout=1)
|
||||
continue
|
||||
if task.error is not None:
|
||||
print(task.error)
|
||||
@ -314,10 +365,11 @@ def thread_render(device):
|
||||
current_state_error = None
|
||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||
except Exception as e:
|
||||
task.error = e
|
||||
print(traceback.format_exc())
|
||||
@ -325,7 +377,8 @@ def thread_render(device):
|
||||
finally:
|
||||
# Task completed
|
||||
task.lock.release()
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||
elif task.error is not None:
|
||||
@ -334,12 +387,21 @@ def thread_render(device):
|
||||
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
|
||||
current_state = ServerStates.Online
|
||||
|
||||
def get_cached_task(session_id:str, update_ttl:bool=False):
|
||||
def get_cached_task(task_id:str, update_ttl:bool=False):
|
||||
# By calling keep before tryGet, wont discard if was expired.
|
||||
if update_ttl and not task_cache.keep(session_id, TASK_TTL):
|
||||
if update_ttl and not task_cache.keep(task_id, TASK_TTL):
|
||||
# Failed to keep task, already gone.
|
||||
return None
|
||||
return task_cache.tryGet(session_id)
|
||||
return task_cache.tryGet(task_id)
|
||||
|
||||
def get_cached_session(session_id:str, update_ttl:bool=False):
|
||||
if update_ttl:
|
||||
session_cache.keep(session_id, TASK_TTL)
|
||||
session = session_cache.tryGet(session_id)
|
||||
if not session:
|
||||
session = SessionState(session_id)
|
||||
session_cache.put(session_id, session, TASK_TTL)
|
||||
return session
|
||||
|
||||
def get_devices():
|
||||
devices = {
|
||||
@ -486,14 +548,16 @@ def shutdown_event(): # Signal render thread to close on shutdown
|
||||
current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
def render(req : ImageRequest):
|
||||
if is_alive() <= 0: # Render thread is dead
|
||||
current_thread_count = is_alive()
|
||||
if current_thread_count <= 0: # Render thread is dead
|
||||
raise ChildProcessError('Rendering thread has died.')
|
||||
|
||||
# Alive, check if task in cache
|
||||
task = task_cache.tryGet(req.session_id)
|
||||
if task and not task.response and not task.error and not task.lock.locked():
|
||||
# Unstarted task pending, deny queueing more than one.
|
||||
raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')
|
||||
#
|
||||
session = get_cached_session(req.session_id, update_ttl=True)
|
||||
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||
if current_thread_count < len(pending_tasks):
|
||||
raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
|
||||
|
||||
from . import runtime
|
||||
r = Request()
|
||||
r.session_id = req.session_id
|
||||
@ -530,13 +594,13 @@ def render(req : ImageRequest):
|
||||
r.stream_image_progress = False
|
||||
|
||||
new_task = RenderTask(r)
|
||||
|
||||
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
||||
if session.put(new_task, TASK_TTL):
|
||||
# Use twice the normal timeout for adding user requests.
|
||||
# Tries to force task_cache.put to fail before tasks_queue.put would.
|
||||
# Tries to force session.put to fail before tasks_queue.put would.
|
||||
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||
try:
|
||||
tasks_queue.append(new_task)
|
||||
idle_event.set()
|
||||
return new_task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
58
ui/server.py
58
ui/server.py
@ -49,14 +49,12 @@ from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
#import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, List, Optional, Union
|
||||
|
||||
from sd_internal import Request, Response, task_manager
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
modifiers_cache = None
|
||||
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
||||
|
||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||
@ -354,21 +352,8 @@ def ping(session_id:str=None):
|
||||
# Alive
|
||||
response = {'status': str(task_manager.current_state)}
|
||||
if session_id:
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
||||
if task:
|
||||
response['task'] = id(task)
|
||||
if task.lock.locked():
|
||||
response['session'] = 'running'
|
||||
elif isinstance(task.error, StopAsyncIteration):
|
||||
response['session'] = 'stopped'
|
||||
elif task.error:
|
||||
response['session'] = 'error'
|
||||
elif not task.buffer_queue.empty():
|
||||
response['session'] = 'buffer'
|
||||
elif task.response:
|
||||
response['session'] = 'completed'
|
||||
else:
|
||||
response['session'] = 'pending'
|
||||
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
||||
response['tasks'] = {id(t): t.status for t in session.tasks}
|
||||
response['devices'] = task_manager.get_devices()
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
@ -408,23 +393,25 @@ def render(req : task_manager.ImageRequest):
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||
'stream': f'/image/stream/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
|
||||
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
|
||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/image/stream/{session_id:str}/{task_id:int}')
|
||||
def stream(session_id:str, task_id:int):
|
||||
@app.get('/image/stream/{task_id:int}')
|
||||
def stream(task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
|
||||
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
|
||||
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if task.buffer_queue.empty() and not task.lock.locked():
|
||||
if task.response:
|
||||
#print(f'Session {session_id} sending cached response')
|
||||
@ -434,22 +421,23 @@ def stream(session_id:str, task_id:int):
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||
|
||||
@app.get('/image/stop')
|
||||
def stop(session_id:str=None):
|
||||
if not session_id:
|
||||
def stop(task: int):
|
||||
if not task:
|
||||
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
||||
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||
task_manager.current_state_error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=False)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration('')
|
||||
task_id = task
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=False)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
|
||||
return {'OK'}
|
||||
|
||||
@app.get('/image/tmp/{session_id}/{img_id:int}')
|
||||
def get_image(session_id, img_id):
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
|
||||
@app.get('/image/tmp/{task_id:int}/{img_id:int}')
|
||||
def get_image(task_id: int, img_id: int):
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
|
||||
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||
try:
|
||||
img_data = task.temp_images[img_id]
|
||||
|
Loading…
Reference in New Issue
Block a user