First draft of multi-task in a single session. (#622)

This commit is contained in:
Marc-Andre Ferland 2022-12-08 00:42:46 -05:00 committed by GitHub
parent f8dee7e25f
commit ba2c966329
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 179 additions and 95 deletions

View File

@ -383,7 +383,7 @@
throw new Error('exception is not an Error or a string.')
}
}
const res = await fetch('/image/stop?session_id=' + SD.sessionId)
const res = await fetch('/image/stop?task=' + this.id)
if (!res.ok) {
console.log('Stop response:', res)
throw new Error(res.statusText)
@ -556,13 +556,19 @@
case TaskStatus.pending:
case TaskStatus.waiting:
// Wait for server status to include this task.
await waitUntil(async () => ((task.#id && serverState.task === task.#id)
|| await Promise.resolve(callback?.call(task))
|| signal?.aborted),
await waitUntil(
async () => {
if (task.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) {
return true
}
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
return true
}
},
TASK_STATE_SERVER_UPDATE_DELAY,
timeout,
)
if (this.#id && serverState.task === this.#id) {
if (this.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) {
this._setStatus(TaskStatus.waiting)
}
if (await Promise.resolve(callback?.call(this)) || signal?.aborted) {
@ -572,19 +578,20 @@
return true
}
// Wait for task to start on server.
await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'pending'
|| await Promise.resolve(callback?.call(task))
|| signal?.aborted),
await waitUntil(
async () => {
if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'pending') {
return true
}
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
return true
}
},
TASK_STATE_SERVER_UPDATE_DELAY,
timeout,
)
if (serverState.task === this.#id
&& (
serverState.session === 'running'
|| serverState.session === 'buffer'
|| serverState.session === 'completed'
)
) {
const state = (typeof serverState.tasks === 'object' ? serverState.tasks[String(task.#id)] : undefined)
if (state === 'running' || state === 'buffer' || state === 'completed') {
this._setStatus(TaskStatus.processing)
}
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
@ -594,9 +601,15 @@
return true
}
case TaskStatus.processing:
await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'running'
|| await Promise.resolve(callback?.call(task))
|| signal?.aborted),
await waitUntil(
async () => {
if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'running') {
return true
}
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
return true
}
},
TASK_STATE_SERVER_UPDATE_DELAY,
timeout,
)
@ -882,7 +895,8 @@
throw e
}
// Update class status and callback.
switch(serverState.session) {
const taskState = (typeof serverState.tasks === 'object' ? serverState.tasks[String(this.id)] : undefined)
switch(taskState) {
case 'pending': // Session has pending tasks.
console.error('Server %o render request %o is still waiting.', serverState, renderRequest)
//Only update status if not already set by waitUntil
@ -915,7 +929,7 @@
return false
default:
if (!progressCallback) {
const err = new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
const err = new Error('Unexpected server task state: ' + taskState || 'Undefined')
this.abort(err)
throw err
}
@ -1065,6 +1079,14 @@
return models
}
function getServerCapacity() {
let activeDevicesCount = Object.keys(serverState?.devices?.active || {}).length
if (window.document.visibilityState === 'hidden') {
activeDevicesCount = 1 + activeDevicesCount
}
return activeDevicesCount
}
function continueTasks() {
if (typeof navigator?.scheduling?.isInputPending === 'function') {
const inputPendingOptions = {
@ -1077,13 +1099,17 @@
return asyncDelay(CONCURRENT_TASK_INTERVAL)
}
}
const serverCapacity = getServerCapacity()
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
eventSource.fireEvent(EVENT_IDLE, {})
eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity, idle: true})
// Calling idle could result in task being added to queue.
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
return asyncDelay(IDLE_COOLDOWN)
}
}
if (task_queue.size < serverCapacity) {
eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity - task_queue.size})
}
const completedTasks = []
for (let [generator, promise] of concurrent_generators.entries()) {
if (promise.isPending) {
@ -1122,7 +1148,6 @@
concurrent_generators.set(generator, promise)
}
const serverCapacity = 2
for (let [task, generator] of task_queue.entries()) {
const cTsk = completedTasks.find((item) => item.generator === generator)
if (cTsk?.promise?.rejectReason || task.hasFailed) {
@ -1211,6 +1236,7 @@
removeEventListener: (...args) => eventSource.removeEventListener(...args),
isServerAvailable,
getServerCapacity,
getSystemInfo,
getDevices,
@ -1228,6 +1254,14 @@
configurable: false,
get: () => serverState,
},
isAvailable: {
configurable: false,
get: () => isServerAvailable(),
},
serverCapacity: {
configurable: false,
get: () => getServerCapacity(),
},
sessionId: {
configurable: false,
get: () => sessionId,

View File

@ -455,9 +455,10 @@ function makeImage() {
}
function onIdle() {
const serverCapacity = SD.serverCapacity
for (const taskEntry of getUncompletedTaskEntries()) {
if (SD.activeTasks.size >= 1) {
continue
if (SD.activeTasks.size >= serverCapacity) {
break
}
const task = htmlTaskMap.get(taskEntry)
if (!task) {

View File

@ -202,12 +202,12 @@ describe('stable-diffusion-ui', function() {
// Wait for server status to update.
await SD.waitUntil(() => {
console.log('Waiting for %s to be received...', renderRequest.task)
return (!SD.serverState.task || SD.serverState.task === renderRequest.task)
return (!SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)])
}, 250, 10 * 60 * 1000)
// Wait for task to start on server.
await SD.waitUntil(() => {
console.log('Waiting for %s to start...', renderRequest.task)
return SD.serverState.task !== renderRequest.task || SD.serverState.session !== 'pending'
return !SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)] !== 'pending'
}, 250)
const reader = new SD.ChunkedStreamReader(renderRequest.stream)

View File

@ -1,6 +1,7 @@
import json
class Request:
request_id: str = None
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""

View File

@ -523,15 +523,16 @@ def update_temp_img(req, x_samples, task_temp_images: list):
del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
return partial_images
# Build and return the appropriate generator for do_mk_img
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples
def empty_callback(x_samples, i):
step_callback()
return empty_callback
thread_data.partial_x_samples = None
@ -639,11 +640,6 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
t_enc = int(req.prompt_strength * req.num_inference_steps)
print(f"target t_enc is {t_enc} steps")
if req.save_to_disk_path is not None:
session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
else:
session_out_path = None
with torch.no_grad():
for n in trange(opt_n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):

View File

@ -37,7 +37,8 @@ class ServerStates:
class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
self.request: Request = req # Initial Request
req.request_id = id(self)
self.request: Request = req # Initial Request
self.response: Any = None # Copy of the last response
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
@ -51,6 +52,22 @@ class RenderTask(): # Task with output queue and completion lock.
self.buffer_queue.task_done()
yield res
except queue.Empty as e: yield
@property
# Derived lifecycle state of this RenderTask, computed from its runtime fields.
# Checks are priority-ordered and the order is load-bearing:
#   locked            -> 'running'   (a worker thread holds the task lock)
#   StopAsyncIteration-> 'stopped'   (cancellation is signalled via this error type)
#   any other error   -> 'error'
#   unread buffer     -> 'buffer'    (output produced but not yet streamed)
#   stored response   -> 'completed'
#   none of the above -> 'pending'
def status(self):
if self.lock.locked():
return 'running'
if isinstance(self.error, StopAsyncIteration):
return 'stopped'
if self.error:
return 'error'
if not self.buffer_queue.empty():
return 'buffer'
if self.response:
return 'completed'
return 'pending'
@property
def is_pending(self):
    # A task is pending while it has neither produced a response nor failed.
    return not (self.response or self.error)
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
@ -101,7 +118,7 @@ class FilterRequest(BaseModel):
output_quality: int = 75
# Temporary cache that allows querying task results for a short time after they complete.
class TaskCache():
class DataCache():
def __init__(self):
self._base = dict()
self._lock: threading.Lock = threading.Lock()
@ -110,7 +127,7 @@ class TaskCache():
def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp
def clean(self) -> None:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
try:
# Create a list of expired keys to delete
to_delete = []
@ -120,16 +137,22 @@ class TaskCache():
to_delete.append(key)
# Remove Items
for key in to_delete:
(_, val) = self._base[key]
if isinstance(val, RenderTask):
print(f'RenderTask {key} expired. Data removed.')
elif isinstance(val, SessionState):
print(f'Session {key} expired. Data removed.')
else:
print(f'Key {key} expired. Data removed.')
del self._base[key]
print(f'Session {key} expired. Data removed.')
finally:
self._lock.release()
def clear(self) -> None:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED)
try: self._base.clear()
finally: self._lock.release()
def delete(self, key: Hashable) -> bool:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED)
try:
if key not in self._base:
return False
@ -138,7 +161,7 @@ class TaskCache():
finally:
self._lock.release()
def keep(self, key: Hashable, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
try:
if key in self._base:
_, value = self._base.get(key)
@ -148,7 +171,7 @@ class TaskCache():
finally:
self._lock.release()
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
try:
self._base[key] = (
self._get_ttl_time(ttl), value
@ -162,7 +185,7 @@ class TaskCache():
finally:
self._lock.release()
def tryGet(self, key: Hashable) -> Any:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
try:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
@ -181,11 +204,37 @@ current_model_path = None
current_vae_path = None
current_hypernetwork_path = None
tasks_queue = []
task_cache = TaskCache()
session_cache = DataCache()
task_cache = DataCache()
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
idle_event: threading.Event = threading.Event()
class SessionState():
    """Tracks the tasks submitted under one session id.

    Task objects themselves live in the shared ``task_cache``; this class only
    remembers their ids and resolves them on demand, so expired tasks drop out
    automatically.
    """
    def __init__(self, id: str):
        self._id = id
        # Ids of tasks submitted in this session, oldest first.
        self._tasks_ids = []
    @property
    def id(self):
        return self._id
    @property
    def tasks(self):
        # Resolve stored ids against task_cache; tasks that expired are skipped.
        tasks = []
        for task_id in self._tasks_ids:
            task = task_cache.tryGet(task_id)
            if task:
                tasks.append(task)
        return tasks
    def put(self, task, ttl=TASK_TTL):
        """Register task in the shared cache and remember its id.

        Returns False (recording nothing) if the cache rejected the task.
        """
        task_id = id(task)
        # Insert into the cache first: previously the id was appended before
        # the cache put, so a failed put left a dangling id in _tasks_ids and
        # the bounded-history eviction below (success path only) never ran.
        if not task_cache.put(task_id, task, ttl):
            return False
        self._tasks_ids.append(task_id)
        # Keep only a bounded history of recent task ids per session.
        while len(self._tasks_ids) > len(render_threads) * 2:
            self._tasks_ids.pop(0)
        return True
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
@ -268,6 +317,7 @@ def thread_render(device):
preload_model()
current_state = ServerStates.Online
while True:
session_cache.clean()
task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']:
print(f'Shutting down thread for device {runtime.thread_data.device}')
@ -279,7 +329,8 @@ def thread_render(device):
return
task = thread_get_next_task()
if task is None:
time.sleep(0.05)
idle_event.clear()
idle_event.wait(timeout=1)
continue
if task.error is not None:
print(task.error)
@ -314,10 +365,11 @@ def thread_render(device):
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
task_cache.keep(task.request.session_id, TASK_TTL)
current_state = ServerStates.Rendering
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
except Exception as e:
task.error = e
print(traceback.format_exc())
@ -325,7 +377,8 @@ def thread_render(device):
finally:
# Task completed
task.lock.release()
task_cache.keep(task.request.session_id, TASK_TTL)
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
elif task.error is not None:
@ -334,12 +387,21 @@ def thread_render(device):
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
current_state = ServerStates.Online
def get_cached_task(session_id:str, update_ttl:bool=False):
def get_cached_task(task_id:str, update_ttl:bool=False):
# By calling keep before tryGet, the entry won't be discarded if it had expired.
if update_ttl and not task_cache.keep(session_id, TASK_TTL):
if update_ttl and not task_cache.keep(task_id, TASK_TTL):
# Failed to keep task, already gone.
return None
return task_cache.tryGet(session_id)
return task_cache.tryGet(task_id)
def get_cached_session(session_id:str, update_ttl:bool=False):
    """Return the SessionState for session_id, creating and caching one if absent."""
    if update_ttl:
        session_cache.keep(session_id, TASK_TTL)
    cached = session_cache.tryGet(session_id)
    if cached:
        return cached
    fresh = SessionState(session_id)
    session_cache.put(session_id, fresh, TASK_TTL)
    return fresh
def get_devices():
devices = {
@ -486,14 +548,16 @@ def shutdown_event(): # Signal render thread to close on shutdown
current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest):
if is_alive() <= 0: # Render thread is dead
current_thread_count = is_alive()
if current_thread_count <= 0: # Render thread is dead
raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache
task = task_cache.tryGet(req.session_id)
if task and not task.response and not task.error and not task.lock.locked():
# Unstarted task pending, deny queueing more than one.
raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')
#
session = get_cached_session(req.session_id, update_ttl=True)
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
if current_thread_count < len(pending_tasks):
raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
from . import runtime
r = Request()
r.session_id = req.session_id
@ -530,13 +594,13 @@ def render(req : ImageRequest):
r.stream_image_progress = False
new_task = RenderTask(r)
if task_cache.put(r.session_id, new_task, TASK_TTL):
if session.put(new_task, TASK_TTL):
# Use twice the normal timeout for adding user requests.
# Tries to force task_cache.put to fail before tasks_queue.put would.
# Tries to force session.put to fail before tasks_queue.put would.
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
try:
tasks_queue.append(new_task)
idle_event.set()
return new_task
finally:
manager_lock.release()

View File

@ -49,14 +49,12 @@ from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
#import queue, threading, time
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager
app = FastAPI()
modifiers_cache = None
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
@ -354,21 +352,8 @@ def ping(session_id:str=None):
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
task = task_manager.get_cached_task(session_id, update_ttl=True)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
session = task_manager.get_cached_session(session_id, update_ttl=True)
response['tasks'] = {id(t): t.status for t in session.tasks}
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
@ -408,23 +393,25 @@ def render(req : task_manager.ImageRequest):
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'stream': f'/image/stream/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
print(e)
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
@app.get('/image/stream/{task_id:int}')
def stream(task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
@ -434,22 +421,23 @@ def stream(session_id:str, task_id:int):
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop')
def stop(session_id:str=None):
if not session_id:
def stop(task: int):
if not task:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task = task_manager.get_cached_task(session_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
task_id = task
task = task_manager.get_cached_task(task_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
@app.get('/image/tmp/{task_id:int}/{img_id:int}')
def get_image(task_id: int, img_id: int):
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]