mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-04-14 22:48:20 +02:00
Render queue first draft
This commit is contained in:
parent
4b88cfa51a
commit
1ec9d986bb
184
ui/media/main.js
184
ui/media/main.js
@ -117,7 +117,7 @@ maskResetButton.innerHTML = 'Clear'
|
||||
maskResetButton.style.fontWeight = 'normal'
|
||||
maskResetButton.style.fontSize = '10pt'
|
||||
|
||||
let serverStatus = 'offline'
|
||||
let serverState = {'status': 'Offline'}
|
||||
let activeTags = []
|
||||
let modifiers = []
|
||||
let lastPromptUsed = ''
|
||||
@ -212,21 +212,38 @@ function getOutputFormat() {
|
||||
}
|
||||
|
||||
function setStatus(statusType, msg, msgType) {
|
||||
if (statusType !== 'server') {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if (msgType == 'error') {
|
||||
// msg = '<span style="color: red">' + msg + '<span>'
|
||||
serverStatusColor.style.color = 'red'
|
||||
serverStatusMsg.style.color = 'red'
|
||||
serverStatusMsg.innerText = 'Stable Diffusion has stopped'
|
||||
} else if (msgType == 'success') {
|
||||
// msg = '<span style="color: green">' + msg + '<span>'
|
||||
serverStatusColor.style.color = 'green'
|
||||
serverStatusMsg.style.color = 'green'
|
||||
serverStatusMsg.innerText = 'Stable Diffusion is ready'
|
||||
serverStatus = 'online'
|
||||
function setServerStatus(msgType, msg) {
|
||||
switch(msgType) {
|
||||
case 'online':
|
||||
serverStatusColor.style.color = 'green'
|
||||
serverStatusMsg.style.color = 'green'
|
||||
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
|
||||
break
|
||||
case 'busy':
|
||||
serverStatusColor.style.color = 'yellow'
|
||||
serverStatusMsg.style.color = 'yellow'
|
||||
serverStatusMsg.innerText = 'Stable Diffusion is ' + msg
|
||||
break
|
||||
case 'error':
|
||||
serverStatusColor.style.color = 'red'
|
||||
serverStatusMsg.style.color = 'red'
|
||||
serverStatusMsg.innerText = 'Stable Diffusion has stopped'
|
||||
break
|
||||
}
|
||||
}
|
||||
function isServerAvailable() {
|
||||
if (typeof serverState !== 'object') {
|
||||
return false
|
||||
}
|
||||
switch (serverState.status) {
|
||||
case 'LoadingModel':
|
||||
case 'Rendering':
|
||||
case 'Online':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@ -250,6 +267,11 @@ function logError(msg, res, outputMsg) {
|
||||
console.log('request error', res)
|
||||
setStatus('request', 'error', 'error')
|
||||
}
|
||||
function asyncDelay(timeout) {
|
||||
return new Promise(function(resolve, reject) {
|
||||
setTimeout(resolve, timeout, true)
|
||||
})
|
||||
}
|
||||
|
||||
function playSound() {
|
||||
const audio = new Audio('/media/ding.mp3')
|
||||
@ -259,16 +281,34 @@ function playSound() {
|
||||
|
||||
async function healthCheck() {
|
||||
try {
|
||||
let res = await fetch('/ping')
|
||||
res = await res.json()
|
||||
|
||||
if (res[0] == 'OK') {
|
||||
setStatus('server', 'online', 'success')
|
||||
let res = undefined
|
||||
if (sessionId) {
|
||||
res = await fetch('/ping?session_id=' + sessionId)
|
||||
} else {
|
||||
setStatus('server', 'offline', 'error')
|
||||
res = await fetch('/ping')
|
||||
}
|
||||
serverState = await res.json()
|
||||
// Set status
|
||||
switch(serverState.status) {
|
||||
case 'Init':
|
||||
// Wait for init to complete before updating status.
|
||||
break
|
||||
case 'Online':
|
||||
setServerStatus('online', 'ready')
|
||||
break
|
||||
case 'LoadingModel':
|
||||
setServerStatus('busy', 'loading model')
|
||||
break
|
||||
case 'Rendering':
|
||||
setServerStatus('busy', 'rendering')
|
||||
break
|
||||
default: // Unavailable
|
||||
setServerStatus('error', serverState.status.toLowerCase())
|
||||
break
|
||||
}
|
||||
serverState.time = Date.now()
|
||||
} catch (e) {
|
||||
setStatus('server', 'offline', 'error')
|
||||
setServerStatus('error', 'offline')
|
||||
}
|
||||
}
|
||||
function resizeInpaintingEditor() {
|
||||
@ -311,7 +351,7 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
if(typeof res != 'object') return
|
||||
res.output.reverse()
|
||||
res.output.forEach((result, index) => {
|
||||
const imageData = result?.data || result?.path + '?t=' + new Date().getTime()
|
||||
const imageData = result?.data || result?.path + '?t=' + Date.now()
|
||||
const imageWidth = reqBody.width
|
||||
const imageHeight = reqBody.height
|
||||
if (!imageData.includes('/')) {
|
||||
@ -409,8 +449,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) {
|
||||
}
|
||||
function getStartNewTaskHandler(reqBody, imageItemElem, mode) {
|
||||
return function() {
|
||||
if (serverStatus !== 'online') {
|
||||
alert('The server is still starting up..')
|
||||
if (!isServerAvailable()) {
|
||||
alert('The server is not available.')
|
||||
return
|
||||
}
|
||||
const imageElem = imageItemElem.querySelector('img')
|
||||
@ -473,37 +513,68 @@ async function doMakeImage(task) {
|
||||
const progressBar = task['progressBar']
|
||||
|
||||
let res = undefined
|
||||
let stepUpdate = undefined
|
||||
try {
|
||||
res = await fetch('/image', {
|
||||
method: 'POST',
|
||||
const lastTask = serverState.task
|
||||
let renderRequest = undefined
|
||||
do {
|
||||
res = await fetch('/render', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify(reqBody)
|
||||
})
|
||||
renderRequest = await res.json()
|
||||
// status_code 503, already a task running.
|
||||
} while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000))
|
||||
if (typeof renderRequest?.stream !== 'string') {
|
||||
console.log('Endpoint response: ', renderRequest)
|
||||
throw new Error('Endpoint response does not contains a response stream url.')
|
||||
}
|
||||
task['taskStatusLabel'].innerText = "Busy/Waiting"
|
||||
do { // Wait for server status to update.
|
||||
await asyncDelay(250)
|
||||
if (!isServerAvailable()) {
|
||||
throw new Error('Connexion with server lost.')
|
||||
}
|
||||
} while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task)
|
||||
if (serverState.session !== 'pending' && serverState.session !== 'running') {
|
||||
throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined')
|
||||
}
|
||||
do { // Wait for task to start on server.
|
||||
await asyncDelay(1500)
|
||||
} while (serverState?.session === 'pending')
|
||||
|
||||
// Task started!
|
||||
res = await fetch(renderRequest.stream, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify(reqBody)
|
||||
})
|
||||
|
||||
task['taskStatusLabel'].innerText = "Processing"
|
||||
task['taskStatusLabel'].classList.add('activeTaskLabel')
|
||||
task['taskStatusLabel'].classList.remove('waitingTaskLabel')
|
||||
|
||||
let stepUpdate = undefined
|
||||
let reader = res.body.getReader()
|
||||
let textDecoder = new TextDecoder()
|
||||
let finalJSON = ''
|
||||
let prevTime = -1
|
||||
let readComplete = false
|
||||
while (true) {
|
||||
let t = new Date().getTime()
|
||||
|
||||
while (!readComplete || finalJSON.length > 0) {
|
||||
let t = Date.now()
|
||||
let jsonStr = ''
|
||||
if (!readComplete) {
|
||||
const {value, done} = await reader.read()
|
||||
if (done) {
|
||||
readComplete = true
|
||||
}
|
||||
if (done && finalJSON.length <= 0 && !value) {
|
||||
break
|
||||
}
|
||||
if (value) {
|
||||
jsonStr = textDecoder.decode(value)
|
||||
}
|
||||
}
|
||||
stepUpdate = undefined
|
||||
try {
|
||||
// hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot.
|
||||
// this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste...
|
||||
@ -537,9 +608,6 @@ async function doMakeImage(task) {
|
||||
throw e
|
||||
}
|
||||
}
|
||||
if (readComplete && finalJSON.length <= 0) {
|
||||
break
|
||||
}
|
||||
if (typeof stepUpdate === 'object' && 'step' in stepUpdate) {
|
||||
let batchSize = stepUpdate.total_steps
|
||||
let overallStepCount = stepUpdate.step + task.batchesDone * batchSize
|
||||
@ -564,6 +632,23 @@ async function doMakeImage(task) {
|
||||
showImages(reqBody, stepUpdate, outputContainer, true)
|
||||
}
|
||||
}
|
||||
if (stepUpdate?.status) {
|
||||
break
|
||||
}
|
||||
if (readComplete && finalJSON.length <= 0) {
|
||||
if (res.status === 200) {
|
||||
await asyncDelay(5000)
|
||||
res = await fetch(renderRequest.stream, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
})
|
||||
reader = res.body.getReader()
|
||||
readComplete = false
|
||||
} else {
|
||||
console.log('Stream stopped: ', res)
|
||||
}
|
||||
}
|
||||
prevTime = t
|
||||
}
|
||||
|
||||
@ -580,27 +665,28 @@ async function doMakeImage(task) {
|
||||
3. Try generating a smaller image.<br/>`
|
||||
}
|
||||
} else {
|
||||
msg = `Unexpected Read Error:<br/><pre>StepUpdate:${JSON.stringify(stepUpdate, undefined, 4)}</pre>`
|
||||
msg = `Unexpected Read Error:<br/><pre>StepUpdate: ${JSON.stringify(stepUpdate, undefined, 4)}</pre>`
|
||||
}
|
||||
logError(msg, res, outputMsg)
|
||||
return false
|
||||
}
|
||||
if (typeof stepUpdate !== 'object' || !res || res.status != 200) {
|
||||
if (serverStatus !== 'online') {
|
||||
if (!isServerAvailable()) {
|
||||
logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg)
|
||||
} else if (typeof res === 'object') {
|
||||
let msg = 'Stable Diffusion had an error reading the response: '
|
||||
try { // 'Response': body stream already read
|
||||
msg += 'Read: ' + await res.text()
|
||||
} catch(e) {
|
||||
msg += 'No error response. '
|
||||
msg += 'Unexpected end of stream. '
|
||||
}
|
||||
if (finalJSON) {
|
||||
msg += 'Buffered data: ' + finalJSON
|
||||
}
|
||||
logError(msg, res, outputMsg)
|
||||
} else {
|
||||
msg = `Unexpected Read Error:<br/><pre>Response:${res}<br/>StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
|
||||
let msg = `Unexpected Read Error:<br/><pre>Response: ${res}<br/>StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}</pre>`
|
||||
logError(msg, res, outputMsg)
|
||||
}
|
||||
progressBar.style.display = 'none'
|
||||
return false
|
||||
@ -654,8 +740,8 @@ async function checkTasks() {
|
||||
|
||||
task.isProcessing = true
|
||||
task['stopTask'].innerHTML = '<i class="fa-solid fa-circle-stop"></i> Stop'
|
||||
task['taskStatusLabel'].innerText = "Processing"
|
||||
task['taskStatusLabel'].className += " activeTaskLabel"
|
||||
task['taskStatusLabel'].innerText = "Starting"
|
||||
task['taskStatusLabel'].classList.add('waitingTaskLabel')
|
||||
|
||||
const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1))
|
||||
const startSeed = task.reqBody.seed || task.seed
|
||||
@ -780,8 +866,8 @@ function getCurrentUserRequest() {
|
||||
}
|
||||
|
||||
function makeImage() {
|
||||
if (serverStatus !== 'online') {
|
||||
alert('The server is still starting up..')
|
||||
if (!isServerAvailable()) {
|
||||
alert('The server is not available.')
|
||||
return
|
||||
}
|
||||
const taskTemplate = getCurrentUserRequest()
|
||||
@ -834,7 +920,7 @@ function createTask(task) {
|
||||
if (task['isProcessing']) {
|
||||
task.isProcessing = false
|
||||
try {
|
||||
let res = await fetch('/image/stop')
|
||||
let res = await fetch('/image/stop?session_id=' + sessionId)
|
||||
} catch (e) {
|
||||
console.log(e)
|
||||
}
|
||||
@ -1094,9 +1180,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider)
|
||||
updatePromptStrength()
|
||||
|
||||
useBetaChannelField.addEventListener('click', async function(e) {
|
||||
if (serverStatus !== 'online') {
|
||||
if (!isServerAvailable()) {
|
||||
// logError('The server is still starting up..')
|
||||
alert('The server is still starting up..')
|
||||
alert('The server is not available.')
|
||||
e.preventDefault()
|
||||
return false
|
||||
}
|
||||
|
325
ui/server.py
325
ui/server.py
@ -14,28 +14,55 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 * 1000 # Discard last session's task timeout
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, StreamingResponse
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, Optional, Union
|
||||
|
||||
from sd_internal import Request, Response
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
model_loaded = False
|
||||
model_is_loading = False
|
||||
|
||||
modifiers_cache = None
|
||||
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
|
||||
|
||||
# don't show access log entries for URLs that start with the given prefix
|
||||
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails']
|
||||
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
|
||||
|
||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
def __repr__(self): return self.__qualname__
|
||||
def __str__(self): return self.__name__
|
||||
class Symbol(metaclass=SymbolClass): pass
|
||||
|
||||
class ServerStates:
|
||||
class Init(Symbol): pass
|
||||
class LoadingModel(Symbol): pass
|
||||
class Online(Symbol): pass
|
||||
class Rendering(Symbol): pass
|
||||
class Unavailable(Symbol): pass
|
||||
|
||||
class RenderTask(): # Task with output queue and completion lock.
|
||||
def __init__(self, req: Request):
|
||||
self.request: Request = req # Initial Request
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.temp_images:[] = [None] * req.num_outputs
|
||||
self.error: Exception = None
|
||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||
|
||||
current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_model_path = None
|
||||
tasks_queue = queue.Queue()
|
||||
|
||||
# defaults from https://huggingface.co/blog/stable_diffusion
|
||||
class ImageRequest(BaseModel):
|
||||
session_id: str = "session"
|
||||
@ -65,38 +92,152 @@ class ImageRequest(BaseModel):
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class TaskCache():
|
||||
def __init__(self):
|
||||
self._base = dict()
|
||||
def _get_ttl_time(self, ttl: int) -> int:
|
||||
return int(time.time()) + ttl
|
||||
def _is_expired(self, timestamp: int) -> bool:
|
||||
return int(time.time()) >= timestamp
|
||||
def clean(self) -> None:
|
||||
for key in self._base:
|
||||
ttl, _ = self._base[key]
|
||||
if self._is_expired(ttl):
|
||||
del self._base[key]
|
||||
def clear(self) -> None:
|
||||
self._base.clear()
|
||||
def delete(self, key: Hashable) -> bool:
|
||||
if key not in self._base:
|
||||
return False
|
||||
del self._base[key]
|
||||
return True
|
||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||
if key in self._base:
|
||||
_, value = self._base.get(key)
|
||||
self._base[key] = (self._get_ttl_time(ttl), value)
|
||||
return True
|
||||
return False
|
||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||
try:
|
||||
self._base[key] = (
|
||||
self._get_ttl_time(ttl), value
|
||||
)
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
def tryGet(self, key: Hashable) -> Any:
|
||||
ttl, value = self._base.get(key, (None, None))
|
||||
if ttl is not None and self._is_expired(ttl):
|
||||
self.delete(key)
|
||||
return None
|
||||
return value
|
||||
|
||||
task_cache = TaskCache()
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = "main"
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers)
|
||||
|
||||
@app.get('/ping')
|
||||
async def ping():
|
||||
global model_loaded, model_is_loading
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
|
||||
def preload_model(file_path=None):
|
||||
global current_state, current_state_error, current_model_path
|
||||
if file_path == None:
|
||||
file_path = get_initial_model_to_load()
|
||||
if file_path == current_model_path:
|
||||
return
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
if model_loaded:
|
||||
return {'OK'}
|
||||
|
||||
if model_is_loading:
|
||||
return {'ERROR'}
|
||||
|
||||
model_is_loading = True
|
||||
|
||||
from sd_internal import runtime
|
||||
|
||||
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
|
||||
|
||||
model_loaded = True
|
||||
model_is_loading = False
|
||||
|
||||
return {'OK'}
|
||||
runtime.load_model_ckpt(ckpt_to_use=file_path)
|
||||
current_model_path = file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
current_model_path = None
|
||||
current_state_error = e
|
||||
current_state = ServerStates.Unavailable
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def thread_render():
|
||||
global current_state, current_state_error
|
||||
from sd_internal import runtime
|
||||
current_state = ServerStates.Online
|
||||
preload_model()
|
||||
while True:
|
||||
task_cache.clean()
|
||||
task = None
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
return
|
||||
try:
|
||||
task = tasks_queue.get(timeout=1)
|
||||
except queue.Empty as e:
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
return
|
||||
else: continue
|
||||
#if current_model_path != task.request.use_stable_diffusion_model:
|
||||
# preload_model(task.request.use_stable_diffusion_model)
|
||||
if current_state_error:
|
||||
task.error = current_state_error
|
||||
continue
|
||||
print(f'Session {task.request.session_id} starting task {id(task)}')
|
||||
current_state = ServerStates.Rendering
|
||||
try:
|
||||
task.lock.acquire(blocking=False)
|
||||
res = runtime.mk_img(task.request)
|
||||
except Exception as e:
|
||||
task.error = e
|
||||
task.lock.release()
|
||||
tasks_queue.task_done()
|
||||
print(traceback.format_exc())
|
||||
continue
|
||||
dataQueue = None
|
||||
if task.request.stream_progress_updates:
|
||||
dataQueue = task.buffer_queue
|
||||
for result in res:
|
||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||
runtime.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||
if dataQueue:
|
||||
dataQueue.put(result)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
task.response = result
|
||||
if 'output' in result:
|
||||
for out_obj in result['output']:
|
||||
if 'path' in out_obj:
|
||||
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
||||
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
|
||||
elif 'data' in out_obj:
|
||||
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
# Task completed
|
||||
task.lock.release()
|
||||
tasks_queue.task_done()
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||
elif task.error is not None:
|
||||
print(f'Session {task.request.session_id} task {id(task)} failed!')
|
||||
else:
|
||||
print(f'Session {task.request.session_id} task {id(task)} completed.')
|
||||
current_state = ServerStates.Online
|
||||
# Start Rendering Thread
|
||||
render_thread = threading.Thread(target=thread_render)
|
||||
render_thread.daemon = True
|
||||
render_thread.start()
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
global current_state_error
|
||||
current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
# needs to support the legacy installations
|
||||
def get_initial_model_to_load():
|
||||
@ -126,7 +267,6 @@ def resolve_model_to_use(model_name):
|
||||
model_path = legacy_model_path
|
||||
else:
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
|
||||
return model_path
|
||||
|
||||
def save_model_to_config(model_name):
|
||||
@ -135,13 +275,42 @@ def save_model_to_config(model_name):
|
||||
config['model'] = {}
|
||||
|
||||
config['model']['stable-diffusion'] = model_name
|
||||
|
||||
setConfig(config)
|
||||
|
||||
@app.post('/image')
|
||||
def image(req : ImageRequest):
|
||||
from sd_internal import runtime
|
||||
@app.get('/ping') # Get server and optionally session status.
|
||||
def ping(session_id:str=None):
|
||||
if current_state_error or not render_thread.is_alive(): # Render thread is dead.
|
||||
return HTTPException(status_code=500, detail=str(current_state_error))
|
||||
# Alive
|
||||
response = {'status': str(current_state)}
|
||||
if session_id:
|
||||
task = task_cache.tryGet(session_id)
|
||||
if task:
|
||||
response['task'] = id(task)
|
||||
if task.lock.locked():
|
||||
response['session'] = 'running'
|
||||
elif isinstance(task.error, StopAsyncIteration):
|
||||
response['session'] = 'stopped'
|
||||
elif task.error:
|
||||
response['session'] = 'error'
|
||||
elif not task.buffer_queue.empty():
|
||||
response['session'] = 'buffer'
|
||||
elif task.response:
|
||||
response['session'] = 'completed'
|
||||
else:
|
||||
response['session'] = 'pending'
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
@app.post('/render')
|
||||
def render(req : ImageRequest):
|
||||
if not render_thread.is_alive(): # Render thread is dead
|
||||
return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||
# Alive, check if task in cache
|
||||
task = task_cache.tryGet(req.session_id)
|
||||
if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one.
|
||||
return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
|
||||
#
|
||||
from sd_internal import runtime
|
||||
r = Request()
|
||||
r.session_id = req.session_id
|
||||
r.prompt = req.prompt
|
||||
@ -173,45 +342,70 @@ def image(req : ImageRequest):
|
||||
|
||||
save_model_to_config(req.use_stable_diffusion_model)
|
||||
|
||||
if not req.stream_progress_updates:
|
||||
r.stream_image_progress = False
|
||||
|
||||
new_task = RenderTask(r)
|
||||
task_cache.put(r.session_id, new_task, TASK_TTL)
|
||||
tasks_queue.put(new_task)
|
||||
|
||||
response = {
|
||||
'status': str(current_state),
|
||||
'queue': tasks_queue.qsize(),
|
||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
async def read_data_generator(data:queue.Queue, lock:threading.Lock):
|
||||
try:
|
||||
if not req.stream_progress_updates:
|
||||
r.stream_image_progress = False
|
||||
while not data.empty():
|
||||
res = data.get(block=False)
|
||||
data.task_done()
|
||||
yield res
|
||||
except queue.Empty as e: yield
|
||||
|
||||
res = runtime.mk_img(r)
|
||||
|
||||
if req.stream_progress_updates:
|
||||
return StreamingResponse(res, media_type='application/json')
|
||||
else: # compatibility mode: buffer the streaming responses, and return the last one
|
||||
last_result = None
|
||||
|
||||
for result in res:
|
||||
last_result = result
|
||||
|
||||
return json.loads(last_result)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
@app.get('/image/stream/{session_id:str}/{task_id:int}')
|
||||
def stream(session_id:str, task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
task = task_cache.tryGet(session_id)
|
||||
if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
|
||||
if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if task.buffer_queue.empty() and not task.lock.locked():
|
||||
if task.response:
|
||||
#print(f'Session {session_id} sending cached response')
|
||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||
return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json')
|
||||
|
||||
@app.get('/image/stop')
|
||||
def stop():
|
||||
try:
|
||||
if model_is_loading:
|
||||
return {'ERROR'}
|
||||
|
||||
from sd_internal import runtime
|
||||
runtime.stop_processing = True
|
||||
|
||||
def stop(session_id:str=None):
|
||||
if not session_id:
|
||||
if current_state == ServerStates.Online or current_state == ServerStates.Unavailable:
|
||||
return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||
global current_state_error
|
||||
current_state_error = StopAsyncIteration()
|
||||
return {'OK'}
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
task = task_cache.tryGet(session_id)
|
||||
if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
|
||||
@app.get('/image/tmp/{session_id}/{img_id}')
|
||||
@app.get('/image/tmp/{session_id}/{img_id:int}')
|
||||
def get_image(session_id, img_id):
|
||||
from sd_internal import runtime
|
||||
buf = runtime.temp_images[session_id + '/' + img_id]
|
||||
buf.seek(0)
|
||||
return StreamingResponse(buf, media_type='image/jpeg')
|
||||
task = task_cache.tryGet(session_id)
|
||||
if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
|
||||
if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||
try:
|
||||
img_data = task.temp_images[img_id]
|
||||
if isinstance(img_data, str):
|
||||
return img_data
|
||||
img_data.seek(0)
|
||||
return StreamingResponse(img_data, media_type='image/jpeg')
|
||||
except KeyError as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
@ -257,7 +451,6 @@ def getConfig(default_val={}):
|
||||
def setConfig(config):
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
|
||||
with open(config_json_path, 'w') as f:
|
||||
return json.dump(config, f)
|
||||
except:
|
||||
@ -316,9 +509,7 @@ class LogSuppressFilter(logging.Filter):
|
||||
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
|
||||
if path.find(prefix) != -1:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
|
||||
# start the browser ui
|
||||
|
Loading…
Reference in New Issue
Block a user