mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-04-16 15:38:37 +02:00
First draft of multi-task in a single session. (#622)
This commit is contained in:
parent
f8dee7e25f
commit
ba2c966329
@ -383,7 +383,7 @@
|
|||||||
throw new Error('exception is not an Error or a string.')
|
throw new Error('exception is not an Error or a string.')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const res = await fetch('/image/stop?session_id=' + SD.sessionId)
|
const res = await fetch('/image/stop?task=' + this.id)
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
console.log('Stop response:', res)
|
console.log('Stop response:', res)
|
||||||
throw new Error(res.statusText)
|
throw new Error(res.statusText)
|
||||||
@ -556,13 +556,19 @@
|
|||||||
case TaskStatus.pending:
|
case TaskStatus.pending:
|
||||||
case TaskStatus.waiting:
|
case TaskStatus.waiting:
|
||||||
// Wait for server status to include this task.
|
// Wait for server status to include this task.
|
||||||
await waitUntil(async () => ((task.#id && serverState.task === task.#id)
|
await waitUntil(
|
||||||
|| await Promise.resolve(callback?.call(task))
|
async () => {
|
||||||
|| signal?.aborted),
|
if (task.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
},
|
||||||
TASK_STATE_SERVER_UPDATE_DELAY,
|
TASK_STATE_SERVER_UPDATE_DELAY,
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
if (this.#id && serverState.task === this.#id) {
|
if (this.#id && typeof serverState.tasks === 'object' && Object.keys(serverState.tasks).includes(String(task.#id))) {
|
||||||
this._setStatus(TaskStatus.waiting)
|
this._setStatus(TaskStatus.waiting)
|
||||||
}
|
}
|
||||||
if (await Promise.resolve(callback?.call(this)) || signal?.aborted) {
|
if (await Promise.resolve(callback?.call(this)) || signal?.aborted) {
|
||||||
@ -572,19 +578,20 @@
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Wait for task to start on server.
|
// Wait for task to start on server.
|
||||||
await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'pending'
|
await waitUntil(
|
||||||
|| await Promise.resolve(callback?.call(task))
|
async () => {
|
||||||
|| signal?.aborted),
|
if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'pending') {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
},
|
||||||
TASK_STATE_SERVER_UPDATE_DELAY,
|
TASK_STATE_SERVER_UPDATE_DELAY,
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
if (serverState.task === this.#id
|
const state = (typeof serverState.tasks === 'object' ? serverState.tasks[String(task.#id)] : undefined)
|
||||||
&& (
|
if (state === 'running' || state === 'buffer' || state === 'completed') {
|
||||||
serverState.session === 'running'
|
|
||||||
|| serverState.session === 'buffer'
|
|
||||||
|| serverState.session === 'completed'
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
this._setStatus(TaskStatus.processing)
|
this._setStatus(TaskStatus.processing)
|
||||||
}
|
}
|
||||||
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||||
@ -594,9 +601,15 @@
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case TaskStatus.processing:
|
case TaskStatus.processing:
|
||||||
await waitUntil(async () => (serverState.task !== task.#id || serverState.session !== 'running'
|
await waitUntil(
|
||||||
|| await Promise.resolve(callback?.call(task))
|
async () => {
|
||||||
|| signal?.aborted),
|
if (typeof serverState.tasks !== 'object' || serverState.tasks[String(task.#id)] !== 'running') {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if (await Promise.resolve(callback?.call(task)) || signal?.aborted) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
},
|
||||||
TASK_STATE_SERVER_UPDATE_DELAY,
|
TASK_STATE_SERVER_UPDATE_DELAY,
|
||||||
timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
@ -882,7 +895,8 @@
|
|||||||
throw e
|
throw e
|
||||||
}
|
}
|
||||||
// Update class status and callback.
|
// Update class status and callback.
|
||||||
switch(serverState.session) {
|
const taskState = (typeof serverState.tasks === 'object' ? serverState.tasks[String(this.id)] : undefined)
|
||||||
|
switch(taskState) {
|
||||||
case 'pending': // Session has pending tasks.
|
case 'pending': // Session has pending tasks.
|
||||||
console.error('Server %o render request %o is still waiting.', serverState, renderRequest)
|
console.error('Server %o render request %o is still waiting.', serverState, renderRequest)
|
||||||
//Only update status if not already set by waitUntil
|
//Only update status if not already set by waitUntil
|
||||||
@ -915,7 +929,7 @@
|
|||||||
return false
|
return false
|
||||||
default:
|
default:
|
||||||
if (!progressCallback) {
|
if (!progressCallback) {
|
||||||
const err = new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
|
const err = new Error('Unexpected server task state: ' + taskState || 'Undefined')
|
||||||
this.abort(err)
|
this.abort(err)
|
||||||
throw err
|
throw err
|
||||||
}
|
}
|
||||||
@ -1065,6 +1079,14 @@
|
|||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getServerCapacity() {
|
||||||
|
let activeDevicesCount = Object.keys(serverState?.devices?.active || {}).length
|
||||||
|
if (window.document.visibilityState === 'hidden') {
|
||||||
|
activeDevicesCount = 1 + activeDevicesCount
|
||||||
|
}
|
||||||
|
return activeDevicesCount
|
||||||
|
}
|
||||||
|
|
||||||
function continueTasks() {
|
function continueTasks() {
|
||||||
if (typeof navigator?.scheduling?.isInputPending === 'function') {
|
if (typeof navigator?.scheduling?.isInputPending === 'function') {
|
||||||
const inputPendingOptions = {
|
const inputPendingOptions = {
|
||||||
@ -1077,13 +1099,17 @@
|
|||||||
return asyncDelay(CONCURRENT_TASK_INTERVAL)
|
return asyncDelay(CONCURRENT_TASK_INTERVAL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const serverCapacity = getServerCapacity()
|
||||||
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
|
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
|
||||||
eventSource.fireEvent(EVENT_IDLE, {})
|
eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity, idle: true})
|
||||||
// Calling idle could result in task being added to queue.
|
// Calling idle could result in task being added to queue.
|
||||||
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
|
if (task_queue.size <= 0 && concurrent_generators.size <= 0) {
|
||||||
return asyncDelay(IDLE_COOLDOWN)
|
return asyncDelay(IDLE_COOLDOWN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (task_queue.size < serverCapacity) {
|
||||||
|
eventSource.fireEvent(EVENT_IDLE, {capacity: serverCapacity - task_queue.size})
|
||||||
|
}
|
||||||
const completedTasks = []
|
const completedTasks = []
|
||||||
for (let [generator, promise] of concurrent_generators.entries()) {
|
for (let [generator, promise] of concurrent_generators.entries()) {
|
||||||
if (promise.isPending) {
|
if (promise.isPending) {
|
||||||
@ -1122,7 +1148,6 @@
|
|||||||
concurrent_generators.set(generator, promise)
|
concurrent_generators.set(generator, promise)
|
||||||
}
|
}
|
||||||
|
|
||||||
const serverCapacity = 2
|
|
||||||
for (let [task, generator] of task_queue.entries()) {
|
for (let [task, generator] of task_queue.entries()) {
|
||||||
const cTsk = completedTasks.find((item) => item.generator === generator)
|
const cTsk = completedTasks.find((item) => item.generator === generator)
|
||||||
if (cTsk?.promise?.rejectReason || task.hasFailed) {
|
if (cTsk?.promise?.rejectReason || task.hasFailed) {
|
||||||
@ -1211,6 +1236,7 @@
|
|||||||
removeEventListener: (...args) => eventSource.removeEventListener(...args),
|
removeEventListener: (...args) => eventSource.removeEventListener(...args),
|
||||||
|
|
||||||
isServerAvailable,
|
isServerAvailable,
|
||||||
|
getServerCapacity,
|
||||||
|
|
||||||
getSystemInfo,
|
getSystemInfo,
|
||||||
getDevices,
|
getDevices,
|
||||||
@ -1228,6 +1254,14 @@
|
|||||||
configurable: false,
|
configurable: false,
|
||||||
get: () => serverState,
|
get: () => serverState,
|
||||||
},
|
},
|
||||||
|
isAvailable: {
|
||||||
|
configurable: false,
|
||||||
|
get: () => isServerAvailable(),
|
||||||
|
},
|
||||||
|
serverCapacity: {
|
||||||
|
configurable: false,
|
||||||
|
get: () => getServerCapacity(),
|
||||||
|
},
|
||||||
sessionId: {
|
sessionId: {
|
||||||
configurable: false,
|
configurable: false,
|
||||||
get: () => sessionId,
|
get: () => sessionId,
|
||||||
|
@ -455,9 +455,10 @@ function makeImage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function onIdle() {
|
function onIdle() {
|
||||||
|
const serverCapacity = SD.serverCapacity
|
||||||
for (const taskEntry of getUncompletedTaskEntries()) {
|
for (const taskEntry of getUncompletedTaskEntries()) {
|
||||||
if (SD.activeTasks.size >= 1) {
|
if (SD.activeTasks.size >= serverCapacity) {
|
||||||
continue
|
break
|
||||||
}
|
}
|
||||||
const task = htmlTaskMap.get(taskEntry)
|
const task = htmlTaskMap.get(taskEntry)
|
||||||
if (!task) {
|
if (!task) {
|
||||||
|
@ -202,12 +202,12 @@ describe('stable-diffusion-ui', function() {
|
|||||||
// Wait for server status to update.
|
// Wait for server status to update.
|
||||||
await SD.waitUntil(() => {
|
await SD.waitUntil(() => {
|
||||||
console.log('Waiting for %s to be received...', renderRequest.task)
|
console.log('Waiting for %s to be received...', renderRequest.task)
|
||||||
return (!SD.serverState.task || SD.serverState.task === renderRequest.task)
|
return (!SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)])
|
||||||
}, 250, 10 * 60 * 1000)
|
}, 250, 10 * 60 * 1000)
|
||||||
// Wait for task to start on server.
|
// Wait for task to start on server.
|
||||||
await SD.waitUntil(() => {
|
await SD.waitUntil(() => {
|
||||||
console.log('Waiting for %s to start...', renderRequest.task)
|
console.log('Waiting for %s to start...', renderRequest.task)
|
||||||
return SD.serverState.task !== renderRequest.task || SD.serverState.session !== 'pending'
|
return !SD.serverState.tasks || SD.serverState.tasks[String(renderRequest.task)] !== 'pending'
|
||||||
}, 250)
|
}, 250)
|
||||||
|
|
||||||
const reader = new SD.ChunkedStreamReader(renderRequest.stream)
|
const reader = new SD.ChunkedStreamReader(renderRequest.stream)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
|
request_id: str = None
|
||||||
session_id: str = "session"
|
session_id: str = "session"
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
negative_prompt: str = ""
|
negative_prompt: str = ""
|
||||||
|
@ -523,15 +523,16 @@ def update_temp_img(req, x_samples, task_temp_images: list):
|
|||||||
del img, x_sample, x_sample_ddim
|
del img, x_sample, x_sample_ddim
|
||||||
# don't delete x_samples, it is used in the code that called this callback
|
# don't delete x_samples, it is used in the code that called this callback
|
||||||
|
|
||||||
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
|
||||||
task_temp_images[i] = buf
|
task_temp_images[i] = buf
|
||||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
|
||||||
return partial_images
|
return partial_images
|
||||||
|
|
||||||
# Build and return the apropriate generator for do_mk_img
|
# Build and return the apropriate generator for do_mk_img
|
||||||
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
|
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
|
||||||
if not req.stream_progress_updates:
|
if not req.stream_progress_updates:
|
||||||
def empty_callback(x_samples, i): return x_samples
|
def empty_callback(x_samples, i):
|
||||||
|
step_callback()
|
||||||
return empty_callback
|
return empty_callback
|
||||||
|
|
||||||
thread_data.partial_x_samples = None
|
thread_data.partial_x_samples = None
|
||||||
@ -639,11 +640,6 @@ def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, ste
|
|||||||
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||||
print(f"target t_enc is {t_enc} steps")
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
if req.save_to_disk_path is not None:
|
|
||||||
session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
|
|
||||||
else:
|
|
||||||
session_out_path = None
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for n in trange(opt_n_iter, desc="Sampling"):
|
for n in trange(opt_n_iter, desc="Sampling"):
|
||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
@ -37,7 +37,8 @@ class ServerStates:
|
|||||||
|
|
||||||
class RenderTask(): # Task with output queue and completion lock.
|
class RenderTask(): # Task with output queue and completion lock.
|
||||||
def __init__(self, req: Request):
|
def __init__(self, req: Request):
|
||||||
self.request: Request = req # Initial Request
|
req.request_id = id(self)
|
||||||
|
self.request: Request = req # Initial Request
|
||||||
self.response: Any = None # Copy of the last reponse
|
self.response: Any = None # Copy of the last reponse
|
||||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||||
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
||||||
@ -51,6 +52,22 @@ class RenderTask(): # Task with output queue and completion lock.
|
|||||||
self.buffer_queue.task_done()
|
self.buffer_queue.task_done()
|
||||||
yield res
|
yield res
|
||||||
except queue.Empty as e: yield
|
except queue.Empty as e: yield
|
||||||
|
@property
|
||||||
|
def status(self):
|
||||||
|
if self.lock.locked():
|
||||||
|
return 'running'
|
||||||
|
if isinstance(self.error, StopAsyncIteration):
|
||||||
|
return 'stopped'
|
||||||
|
if self.error:
|
||||||
|
return 'error'
|
||||||
|
if not self.buffer_queue.empty():
|
||||||
|
return 'buffer'
|
||||||
|
if self.response:
|
||||||
|
return 'completed'
|
||||||
|
return 'pending'
|
||||||
|
@property
|
||||||
|
def is_pending(self):
|
||||||
|
return bool(not self.response and not self.error)
|
||||||
|
|
||||||
# defaults from https://huggingface.co/blog/stable_diffusion
|
# defaults from https://huggingface.co/blog/stable_diffusion
|
||||||
class ImageRequest(BaseModel):
|
class ImageRequest(BaseModel):
|
||||||
@ -101,7 +118,7 @@ class FilterRequest(BaseModel):
|
|||||||
output_quality: int = 75
|
output_quality: int = 75
|
||||||
|
|
||||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||||
class TaskCache():
|
class DataCache():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._base = dict()
|
self._base = dict()
|
||||||
self._lock: threading.Lock = threading.Lock()
|
self._lock: threading.Lock = threading.Lock()
|
||||||
@ -110,7 +127,7 @@ class TaskCache():
|
|||||||
def _is_expired(self, timestamp: int) -> bool:
|
def _is_expired(self, timestamp: int) -> bool:
|
||||||
return int(time.time()) >= timestamp
|
return int(time.time()) >= timestamp
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
# Create a list of expired keys to delete
|
# Create a list of expired keys to delete
|
||||||
to_delete = []
|
to_delete = []
|
||||||
@ -120,16 +137,22 @@ class TaskCache():
|
|||||||
to_delete.append(key)
|
to_delete.append(key)
|
||||||
# Remove Items
|
# Remove Items
|
||||||
for key in to_delete:
|
for key in to_delete:
|
||||||
|
(_, val) = self._base[key]
|
||||||
|
if isinstance(val, RenderTask):
|
||||||
|
print(f'RenderTask {key} expired. Data removed.')
|
||||||
|
elif isinstance(val, SessionState):
|
||||||
|
print(f'Session {key} expired. Data removed.')
|
||||||
|
else:
|
||||||
|
print(f'Key {key} expired. Data removed.')
|
||||||
del self._base[key]
|
del self._base[key]
|
||||||
print(f'Session {key} expired. Data removed.')
|
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED)
|
||||||
try: self._base.clear()
|
try: self._base.clear()
|
||||||
finally: self._lock.release()
|
finally: self._lock.release()
|
||||||
def delete(self, key: Hashable) -> bool:
|
def delete(self, key: Hashable) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
if key not in self._base:
|
if key not in self._base:
|
||||||
return False
|
return False
|
||||||
@ -138,7 +161,7 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
if key in self._base:
|
if key in self._base:
|
||||||
_, value = self._base.get(key)
|
_, value = self._base.get(key)
|
||||||
@ -148,7 +171,7 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
self._base[key] = (
|
self._base[key] = (
|
||||||
self._get_ttl_time(ttl), value
|
self._get_ttl_time(ttl), value
|
||||||
@ -162,7 +185,7 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def tryGet(self, key: Hashable) -> Any:
|
def tryGet(self, key: Hashable) -> Any:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
ttl, value = self._base.get(key, (None, None))
|
ttl, value = self._base.get(key, (None, None))
|
||||||
if ttl is not None and self._is_expired(ttl):
|
if ttl is not None and self._is_expired(ttl):
|
||||||
@ -181,11 +204,37 @@ current_model_path = None
|
|||||||
current_vae_path = None
|
current_vae_path = None
|
||||||
current_hypernetwork_path = None
|
current_hypernetwork_path = None
|
||||||
tasks_queue = []
|
tasks_queue = []
|
||||||
task_cache = TaskCache()
|
session_cache = DataCache()
|
||||||
|
task_cache = DataCache()
|
||||||
default_model_to_load = None
|
default_model_to_load = None
|
||||||
default_vae_to_load = None
|
default_vae_to_load = None
|
||||||
default_hypernetwork_to_load = None
|
default_hypernetwork_to_load = None
|
||||||
weak_thread_data = weakref.WeakKeyDictionary()
|
weak_thread_data = weakref.WeakKeyDictionary()
|
||||||
|
idle_event: threading.Event = threading.Event()
|
||||||
|
|
||||||
|
class SessionState():
|
||||||
|
def __init__(self, id: str):
|
||||||
|
self._id = id
|
||||||
|
self._tasks_ids = []
|
||||||
|
@property
|
||||||
|
def id(self):
|
||||||
|
return self._id
|
||||||
|
@property
|
||||||
|
def tasks(self):
|
||||||
|
tasks = []
|
||||||
|
for task_id in self._tasks_ids:
|
||||||
|
task = task_cache.tryGet(task_id)
|
||||||
|
if task:
|
||||||
|
tasks.append(task)
|
||||||
|
return tasks
|
||||||
|
def put(self, task, ttl=TASK_TTL):
|
||||||
|
task_id = id(task)
|
||||||
|
self._tasks_ids.append(task_id)
|
||||||
|
if not task_cache.put(task_id, task, ttl):
|
||||||
|
return False
|
||||||
|
while len(self._tasks_ids) > len(render_threads) * 2:
|
||||||
|
self._tasks_ids.pop(0)
|
||||||
|
return True
|
||||||
|
|
||||||
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
|
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
|
||||||
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
|
||||||
@ -268,6 +317,7 @@ def thread_render(device):
|
|||||||
preload_model()
|
preload_model()
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
while True:
|
while True:
|
||||||
|
session_cache.clean()
|
||||||
task_cache.clean()
|
task_cache.clean()
|
||||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
if not weak_thread_data[threading.current_thread()]['alive']:
|
||||||
print(f'Shutting down thread for device {runtime.thread_data.device}')
|
print(f'Shutting down thread for device {runtime.thread_data.device}')
|
||||||
@ -279,7 +329,8 @@ def thread_render(device):
|
|||||||
return
|
return
|
||||||
task = thread_get_next_task()
|
task = thread_get_next_task()
|
||||||
if task is None:
|
if task is None:
|
||||||
time.sleep(0.05)
|
idle_event.clear()
|
||||||
|
idle_event.wait(timeout=1)
|
||||||
continue
|
continue
|
||||||
if task.error is not None:
|
if task.error is not None:
|
||||||
print(task.error)
|
print(task.error)
|
||||||
@ -314,10 +365,11 @@ def thread_render(device):
|
|||||||
current_state_error = None
|
current_state_error = None
|
||||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||||
|
|
||||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
|
||||||
|
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
|
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||||
|
# Before looping back to the generator, mark cache as still alive.
|
||||||
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
|
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
task.error = e
|
task.error = e
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
@ -325,7 +377,8 @@ def thread_render(device):
|
|||||||
finally:
|
finally:
|
||||||
# Task completed
|
# Task completed
|
||||||
task.lock.release()
|
task.lock.release()
|
||||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
|
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
if isinstance(task.error, StopAsyncIteration):
|
if isinstance(task.error, StopAsyncIteration):
|
||||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||||
elif task.error is not None:
|
elif task.error is not None:
|
||||||
@ -334,12 +387,21 @@ def thread_render(device):
|
|||||||
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
|
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
|
|
||||||
def get_cached_task(session_id:str, update_ttl:bool=False):
|
def get_cached_task(task_id:str, update_ttl:bool=False):
|
||||||
# By calling keep before tryGet, wont discard if was expired.
|
# By calling keep before tryGet, wont discard if was expired.
|
||||||
if update_ttl and not task_cache.keep(session_id, TASK_TTL):
|
if update_ttl and not task_cache.keep(task_id, TASK_TTL):
|
||||||
# Failed to keep task, already gone.
|
# Failed to keep task, already gone.
|
||||||
return None
|
return None
|
||||||
return task_cache.tryGet(session_id)
|
return task_cache.tryGet(task_id)
|
||||||
|
|
||||||
|
def get_cached_session(session_id:str, update_ttl:bool=False):
|
||||||
|
if update_ttl:
|
||||||
|
session_cache.keep(session_id, TASK_TTL)
|
||||||
|
session = session_cache.tryGet(session_id)
|
||||||
|
if not session:
|
||||||
|
session = SessionState(session_id)
|
||||||
|
session_cache.put(session_id, session, TASK_TTL)
|
||||||
|
return session
|
||||||
|
|
||||||
def get_devices():
|
def get_devices():
|
||||||
devices = {
|
devices = {
|
||||||
@ -486,14 +548,16 @@ def shutdown_event(): # Signal render thread to close on shutdown
|
|||||||
current_state_error = SystemExit('Application shutting down.')
|
current_state_error = SystemExit('Application shutting down.')
|
||||||
|
|
||||||
def render(req : ImageRequest):
|
def render(req : ImageRequest):
|
||||||
if is_alive() <= 0: # Render thread is dead
|
current_thread_count = is_alive()
|
||||||
|
if current_thread_count <= 0: # Render thread is dead
|
||||||
raise ChildProcessError('Rendering thread has died.')
|
raise ChildProcessError('Rendering thread has died.')
|
||||||
|
|
||||||
# Alive, check if task in cache
|
# Alive, check if task in cache
|
||||||
task = task_cache.tryGet(req.session_id)
|
session = get_cached_session(req.session_id, update_ttl=True)
|
||||||
if task and not task.response and not task.error and not task.lock.locked():
|
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||||
# Unstarted task pending, deny queueing more than one.
|
if current_thread_count < len(pending_tasks):
|
||||||
raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')
|
raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
|
||||||
#
|
|
||||||
from . import runtime
|
from . import runtime
|
||||||
r = Request()
|
r = Request()
|
||||||
r.session_id = req.session_id
|
r.session_id = req.session_id
|
||||||
@ -530,13 +594,13 @@ def render(req : ImageRequest):
|
|||||||
r.stream_image_progress = False
|
r.stream_image_progress = False
|
||||||
|
|
||||||
new_task = RenderTask(r)
|
new_task = RenderTask(r)
|
||||||
|
if session.put(new_task, TASK_TTL):
|
||||||
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
|
||||||
# Use twice the normal timeout for adding user requests.
|
# Use twice the normal timeout for adding user requests.
|
||||||
# Tries to force task_cache.put to fail before tasks_queue.put would.
|
# Tries to force session.put to fail before tasks_queue.put would.
|
||||||
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||||
try:
|
try:
|
||||||
tasks_queue.append(new_task)
|
tasks_queue.append(new_task)
|
||||||
|
idle_event.set()
|
||||||
return new_task
|
return new_task
|
||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
|
58
ui/server.py
58
ui/server.py
@ -49,14 +49,12 @@ from fastapi.staticfiles import StaticFiles
|
|||||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import logging
|
import logging
|
||||||
#import queue, threading, time
|
|
||||||
from typing import Any, Generator, Hashable, List, Optional, Union
|
from typing import Any, Generator, Hashable, List, Optional, Union
|
||||||
|
|
||||||
from sd_internal import Request, Response, task_manager
|
from sd_internal import Request, Response, task_manager
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
modifiers_cache = None
|
|
||||||
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
||||||
|
|
||||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||||
@ -354,21 +352,8 @@ def ping(session_id:str=None):
|
|||||||
# Alive
|
# Alive
|
||||||
response = {'status': str(task_manager.current_state)}
|
response = {'status': str(task_manager.current_state)}
|
||||||
if session_id:
|
if session_id:
|
||||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
||||||
if task:
|
response['tasks'] = {id(t): t.status for t in session.tasks}
|
||||||
response['task'] = id(task)
|
|
||||||
if task.lock.locked():
|
|
||||||
response['session'] = 'running'
|
|
||||||
elif isinstance(task.error, StopAsyncIteration):
|
|
||||||
response['session'] = 'stopped'
|
|
||||||
elif task.error:
|
|
||||||
response['session'] = 'error'
|
|
||||||
elif not task.buffer_queue.empty():
|
|
||||||
response['session'] = 'buffer'
|
|
||||||
elif task.response:
|
|
||||||
response['session'] = 'completed'
|
|
||||||
else:
|
|
||||||
response['session'] = 'pending'
|
|
||||||
response['devices'] = task_manager.get_devices()
|
response['devices'] = task_manager.get_devices()
|
||||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||||
|
|
||||||
@ -408,23 +393,25 @@ def render(req : task_manager.ImageRequest):
|
|||||||
response = {
|
response = {
|
||||||
'status': str(task_manager.current_state),
|
'status': str(task_manager.current_state),
|
||||||
'queue': len(task_manager.tasks_queue),
|
'queue': len(task_manager.tasks_queue),
|
||||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
'stream': f'/image/stream/{id(new_task)}',
|
||||||
'task': id(new_task)
|
'task': id(new_task)
|
||||||
}
|
}
|
||||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||||
except ChildProcessError as e: # Render thread is dead
|
except ChildProcessError as e: # Render thread is dead
|
||||||
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||||
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
|
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||||
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
|
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@app.get('/image/stream/{session_id:str}/{task_id:int}')
|
@app.get('/image/stream/{task_id:int}')
|
||||||
def stream(session_id:str, task_id:int):
|
def stream(task_id:int):
|
||||||
#TODO Move to WebSockets ??
|
#TODO Move to WebSockets ??
|
||||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||||
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
|
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
|
||||||
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||||
if task.buffer_queue.empty() and not task.lock.locked():
|
if task.buffer_queue.empty() and not task.lock.locked():
|
||||||
if task.response:
|
if task.response:
|
||||||
#print(f'Session {session_id} sending cached response')
|
#print(f'Session {session_id} sending cached response')
|
||||||
@ -434,22 +421,23 @@ def stream(session_id:str, task_id:int):
|
|||||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||||
|
|
||||||
@app.get('/image/stop')
|
@app.get('/image/stop')
|
||||||
def stop(session_id:str=None):
|
def stop(task: int):
|
||||||
if not session_id:
|
if not task:
|
||||||
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
||||||
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||||
task_manager.current_state_error = StopAsyncIteration('')
|
task_manager.current_state_error = StopAsyncIteration('')
|
||||||
return {'OK'}
|
return {'OK'}
|
||||||
task = task_manager.get_cached_task(session_id, update_ttl=False)
|
task_id = task
|
||||||
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
|
task = task_manager.get_cached_task(task_id, update_ttl=False)
|
||||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
|
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
|
||||||
task.error = StopAsyncIteration('')
|
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
|
||||||
|
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
|
||||||
return {'OK'}
|
return {'OK'}
|
||||||
|
|
||||||
@app.get('/image/tmp/{session_id}/{img_id:int}')
|
@app.get('/image/tmp/{task_id:int}/{img_id:int}')
|
||||||
def get_image(session_id, img_id):
|
def get_image(task_id: int, img_id: int):
|
||||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||||
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
|
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
|
||||||
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||||
try:
|
try:
|
||||||
img_data = task.temp_images[img_id]
|
img_data = task.temp_images[img_id]
|
||||||
|
Loading…
Reference in New Issue
Block a user