diff --git a/ui/media/main.css b/ui/media/main.css index b2c984fb..33b9027b 100644 --- a/ui/media/main.css +++ b/ui/media/main.css @@ -481,6 +481,11 @@ img { border: 1px solid rgb(0, 75, 19); color:rgb(204, 255, 217) } +.waitingTaskLabel { + background:rgb(90, 90, 0); + border: 1px solid rgb(0, 75, 19); + color:rgb(255, 255, 204) +} .secondaryButton { background: rgb(132, 8, 0); border: 1px solid rgb(122, 29, 0); diff --git a/ui/media/main.js b/ui/media/main.js index da28d7bb..3afb8ce9 100644 --- a/ui/media/main.js +++ b/ui/media/main.js @@ -20,7 +20,7 @@ const INPAINTING_EDITOR_SIZE = 450 const IMAGE_REGEX = new RegExp('data:image/[A-Za-z]+;base64') -let sessionId = new Date().getTime() +let sessionId = Date.now() let promptField = document.querySelector('#prompt') let promptsFromFileSelector = document.querySelector('#prompt_from_file') @@ -122,7 +122,7 @@ maskResetButton.innerHTML = 'Clear' maskResetButton.style.fontWeight = 'normal' maskResetButton.style.fontSize = '10pt' -let serverStatus = 'offline' +let serverState = {'status': 'Offline', 'time': Date.now()} let activeTags = [] let modifiers = [] let lastPromptUsed = '' @@ -225,21 +225,38 @@ function getOutputFormat() { } function setStatus(statusType, msg, msgType) { - if (statusType !== 'server') { - return - } +} - if (msgType == 'error') { - // msg = '' + msg + '' - serverStatusColor.style.color = 'red' - serverStatusMsg.style.color = 'red' - serverStatusMsg.innerText = 'Stable Diffusion has stopped' - } else if (msgType == 'success') { - // msg = '' + msg + '' - serverStatusColor.style.color = 'green' - serverStatusMsg.style.color = 'green' - serverStatusMsg.innerText = 'Stable Diffusion is ready' - serverStatus = 'online' +function setServerStatus(msgType, msg) { + switch(msgType) { + case 'online': + serverStatusColor.style.color = 'green' + serverStatusMsg.style.color = 'green' + serverStatusMsg.innerText = 'Stable Diffusion is ' + msg + break + case 'busy': + serverStatusColor.style.color = 'yellow' + serverStatusMsg.style.color = 'yellow' + serverStatusMsg.innerText = 'Stable Diffusion is ' + msg + break + case 'error': + serverStatusColor.style.color = 'red' + serverStatusMsg.style.color = 'red' + serverStatusMsg.innerText = 'Stable Diffusion has stopped' + break + } +} +function isServerAvailable() { + if (typeof serverState !== 'object') { + return false + } + switch (serverState.status) { + case 'LoadingModel': + case 'Rendering': + case 'Online': + return true + default: + return false } } @@ -263,6 +280,11 @@ function logError(msg, res, outputMsg) { console.log('request error', res) setStatus('request', 'error', 'error') } +function asyncDelay(timeout) { + return new Promise(function(resolve, reject) { + setTimeout(resolve, timeout, true) + }) +} function playSound() { const audio = new Audio('/media/ding.mp3') @@ -277,16 +299,40 @@ function playSound() { async function healthCheck() { try { - let res = await fetch('/ping') - res = await res.json() - - if (res[0] == 'OK') { - setStatus('server', 'online', 'success') + let res = undefined + if (sessionId) { + res = await fetch('/ping?session_id=' + sessionId) } else { - setStatus('server', 'offline', 'error') + res = await fetch('/ping') } + serverState = await res.json() + if (typeof serverState !== 'object' || typeof serverState.status !== 'string') { + serverState = {'status': 'Offline', 'time': Date.now()} + setServerStatus('error', 'offline') + return + } + // Set status + switch(serverState.status) { + case 'Init': + // Wait for init to complete before updating status. + break + case 'Online': + setServerStatus('online', 'ready') + break + case 'LoadingModel': + setServerStatus('busy', 'loading model') + break + case 'Rendering': + setServerStatus('busy', 'rendering') + break + default: // Unavailable + setServerStatus('error', serverState.status.toLowerCase()) + break + } + serverState.time = Date.now() } catch (e) { - setStatus('server', 'offline', 'error') + serverState = {'status': 'Offline', 'time': Date.now()} + setServerStatus('error', 'offline') } } function resizeInpaintingEditor() { @@ -329,7 +375,7 @@ function showImages(reqBody, res, outputContainer, livePreview) { if(typeof res != 'object') return res.output.reverse() res.output.forEach((result, index) => { - const imageData = result?.data || result?.path + '?t=' + new Date().getTime(), + const imageData = result?.data || result?.path + '?t=' + Date.now(), imageSeed = result?.seed, imagePrompt = reqBody.prompt, imageInferenceSteps = reqBody.num_inference_steps, @@ -440,8 +486,8 @@ function getSaveImageHandler(imageItemElem, outputFormat) { } function getStartNewTaskHandler(reqBody, imageItemElem, mode) { return function() { - if (serverStatus !== 'online') { - alert('The server is still starting up..') + if (!isServerAvailable()) { + alert('The server is not available.') return } const imageElem = imageItemElem.querySelector('img') @@ -507,37 +553,72 @@ async function doMakeImage(task) { const progressBar = task['progressBar'] let res = undefined - let stepUpdate = undefined try { - res = await fetch('/image', { - method: 'POST', + const lastTask = serverState.task + let renderRequest = undefined + do { + res = await fetch('/render', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(reqBody) + }) + renderRequest = await res.json() + // status_code 503, already a task running. + } while (renderRequest.status_code === 503 && await asyncDelay(30 * 1000)) + if (typeof renderRequest?.stream !== 'string') { + console.log('Endpoint response: ', renderRequest) + throw new Error('Endpoint response does not contains a response stream url.') + } + task['taskStatusLabel'].innerText = "Busy/Waiting" + task['taskStatusLabel'].classList.add('waitingTaskLabel') + task['taskStatusLabel'].classList.remove('activeTaskLabel') + + do { // Wait for server status to update. + await asyncDelay(250) + if (!isServerAvailable()) { + throw new Error('Connexion with server lost.') + } + } while (serverState.time > (Date.now() - (10 * 1000)) && serverState.task !== renderRequest.task) + if (serverState.session !== 'pending' && serverState.session !== 'running' && serverState.session !== 'buffer') { + throw new Error('Unexpected server task state: ' + serverState.session || 'Undefined') + } + while (serverState.task === renderRequest.task && serverState.session === 'pending') { + // Wait for task to start on server. + await asyncDelay(1500) + } + + // Task started! + res = await fetch(renderRequest.stream, { headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(reqBody) }) + task['taskStatusLabel'].innerText = "Processing" + task['taskStatusLabel'].classList.add('activeTaskLabel') + task['taskStatusLabel'].classList.remove('waitingTaskLabel') + + let stepUpdate = undefined let reader = res.body.getReader() let textDecoder = new TextDecoder() let finalJSON = '' let prevTime = -1 let readComplete = false - while (true) { - let t = new Date().getTime() - + while (!readComplete || finalJSON.length > 0) { + let t = Date.now() let jsonStr = '' if (!readComplete) { const {value, done} = await reader.read() if (done) { readComplete = true } - if (done && finalJSON.length <= 0 && !value) { - break - } if (value) { jsonStr = textDecoder.decode(value) } } + stepUpdate = undefined try { // hack for a middleman buffering all the streaming updates, and unleashing them on the poor browser in one shot. // this results in having to parse JSON like {"step": 1}{"step": 2}{"step": 3}{"ste... @@ -571,9 +652,6 @@ async function doMakeImage(task) { throw e } } - if (readComplete && finalJSON.length <= 0) { - break - } if (typeof stepUpdate === 'object' && 'step' in stepUpdate) { let batchSize = stepUpdate.total_steps let overallStepCount = stepUpdate.step + task.batchesDone * batchSize @@ -598,6 +676,23 @@ async function doMakeImage(task) { showImages(reqBody, stepUpdate, outputContainer, true) } } + if (stepUpdate?.status) { + break + } + if (readComplete && finalJSON.length <= 0) { + if (res.status === 200) { + await asyncDelay(5000) + res = await fetch(renderRequest.stream, { + headers: { + 'Content-Type': 'application/json' + }, + }) + reader = res.body.getReader() + readComplete = false + } else { + console.log('Stream stopped: ', res) + } + } prevTime = t } @@ -614,27 +709,28 @@ async function doMakeImage(task) { 3. Try generating a smaller image.
` } } else { - msg = `Unexpected Read Error:
StepUpdate:${JSON.stringify(stepUpdate, undefined, 4)}
` + msg = `Unexpected Read Error:
StepUpdate: ${JSON.stringify(stepUpdate, undefined, 4)}
` } logError(msg, res, outputMsg) return false } if (typeof stepUpdate !== 'object' || !res || res.status != 200) { - if (serverStatus !== 'online') { + if (!isServerAvailable()) { logError("Stable Diffusion is still starting up, please wait. If this goes on beyond a few minutes, Stable Diffusion has probably crashed. Please check the error message in the command-line window.", res, outputMsg) } else if (typeof res === 'object') { let msg = 'Stable Diffusion had an error reading the response: ' try { // 'Response': body stream already read msg += 'Read: ' + await res.text() } catch(e) { - msg += 'No error response. ' + msg += 'Unexpected end of stream. ' } if (finalJSON) { msg += 'Buffered data: ' + finalJSON } logError(msg, res, outputMsg) } else { - msg = `Unexpected Read Error:
Response:${res}
StepUpdate:${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}
` + let msg = `Unexpected Read Error:
Response: ${res}
StepUpdate: ${typeof stepUpdate === 'object' ? JSON.stringify(stepUpdate, undefined, 4) : stepUpdate}
` + logError(msg, res, outputMsg) } progressBar.style.display = 'none' return false @@ -682,14 +778,14 @@ async function checkTasks() { let task = taskQueue.pop() currentTask = task - let time = new Date().getTime() + let time = Date.now() let successCount = 0 task.isProcessing = true task['stopTask'].innerHTML = ' Stop' - task['taskStatusLabel'].innerText = "Processing" - task['taskStatusLabel'].className += " activeTaskLabel" + task['taskStatusLabel'].innerText = "Starting" + task['taskStatusLabel'].classList.add('waitingTaskLabel') const genSeeds = Boolean(typeof task.reqBody.seed !== 'number' || (task.reqBody.seed === task.seed && task.numOutputsTotal > 1)) const startSeed = task.reqBody.seed || task.seed @@ -724,7 +820,7 @@ async function checkTasks() { task['stopTask'].innerHTML = ' Remove' task['taskStatusLabel'].style.display = 'none' - time = new Date().getTime() - time + time = Date.now() - time time /= 1000 if (successCount === task.batchCount) { @@ -814,8 +910,8 @@ function getCurrentUserRequest() { } function makeImage() { - if (serverStatus !== 'online') { - alert('The server is still starting up..') + if (!isServerAvailable()) { + alert('The server is not available.') return } const taskTemplate = getCurrentUserRequest() @@ -868,7 +964,7 @@ function createTask(task) { if (task['isProcessing']) { task.isProcessing = false try { - let res = await fetch('/image/stop') + let res = await fetch('/image/stop?session_id=' + sessionId) } catch (e) { console.log(e) } @@ -1135,9 +1231,9 @@ promptStrengthField.addEventListener('input', updatePromptStrengthSlider) updatePromptStrength() useBetaChannelField.addEventListener('click', async function(e) { - if (serverStatus !== 'online') { + if (!isServerAvailable()) { // logError('The server is still starting up..') - alert('The server is still starting up..') + alert('The server is not available.') e.preventDefault() return false } @@ -1164,7 +1260,7 @@ useBetaChannelField.addEventListener('click', async function(e) { async function getAppConfig() { try { - let res = await fetch('/app_config') + let res = await fetch('/get/app_config') const config = await res.json() if (config.update_branch === 'beta') { @@ -1180,7 +1276,7 @@ async function getAppConfig() { async function getModels() { try { - let res = await fetch('/models') + let res = await fetch('/get/models') const models = await res.json() let activeModel = models['active'] @@ -1451,10 +1547,10 @@ async function getDiskPath() { return } - let res = await fetch('/output_dir') + let res = await fetch('/get/output_dir') if (res.status === 200) { res = await res.json() - res = res[0] + res = res.output_dir document.querySelector('#diskPath').value = res } @@ -1562,14 +1658,15 @@ function resizeModifierCards(val) { const classes = card.className.split(' ').filter(c => !c.startsWith(cardSizePrefix)) card.className = classes.join(' ').trim() - if(val != 0) + if(val != 0) { card.classList.add(cardSize(val)) + } }) } async function loadModifiers() { try { - let res = await fetch('/modifiers.json?v=2') + let res = await fetch('/get/modifiers') if (res.status === 200) { res = await res.json() diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 63506579..0b0a3003 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -197,6 +197,35 @@ def load_model_real_esrgan(real_esrgan_to_use): print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision) +def get_base_path(disk_path, session_id, prompt, ext, suffix=None): + if disk_path is None: return None + if session_id is None: return None + if ext is None: raise Exception('Missing ext') + + session_out_path = os.path.join(disk_path, session_id) + os.makedirs(session_out_path, exist_ok=True) + + prompt_flattened = filename_regex.sub('_', prompt)[:50] + img_id = str(uuid.uuid4())[-8:] + + if suffix is not None: + return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}") + return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}") + +def apply_filters(filter_name, image_data): + print(f'Applying filter {filter_name}...') + gc() + + if filter_name == 'gfpgan': + _, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + image_data = output[:,:,::-1] + + if filter_name == 'real_esrgan': + output, _ = model_real_esrgan.enhance(image_data[:,:,::-1]) + image_data = output[:,:,::-1] + + return image_data + def mk_img(req: Request): try: yield from do_mk_img(req) @@ -283,23 +312,11 @@ def do_mk_img(req: Request): opt_prompt = req.prompt opt_seed = req.seed - opt_n_samples = req.num_outputs opt_n_iter = 1 - opt_scale = req.guidance_scale opt_C = 4 - opt_H = req.height - opt_W = req.width opt_f = 8 - opt_ddim_steps = req.num_inference_steps opt_ddim_eta = 0.0 - opt_strength = req.prompt_strength - opt_save_to_disk_path = req.save_to_disk_path opt_init_img = req.init_image - opt_use_face_correction = req.use_face_correction - opt_use_upscale = req.use_upscale - opt_show_only_filtered = req.show_only_filtered_image - opt_format = req.output_format - opt_sampler_name = req.sampler print(req.to_string(), '\n device', device) @@ -307,7 +324,7 @@ def do_mk_img(req: Request): seed_everything(opt_seed) - batch_size = opt_n_samples + batch_size = req.num_outputs prompt = opt_prompt assert prompt is not None data = [batch_size * [prompt]] @@ -327,7 +344,7 @@ def do_mk_img(req: Request): else: handler = _img2img - init_image = load_img(req.init_image, opt_W, opt_H) + init_image = load_img(req.init_image, req.width, req.height) init_image = init_image.to(device) if device != "cpu" and precision == "autocast": @@ -339,7 +356,7 @@ def do_mk_img(req: Request): init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space if req.mask is not None: - mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device) + mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device) mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) mask = repeat(mask, '1 ... -> b ...', b=batch_size) @@ -348,12 +365,12 @@ def do_mk_img(req: Request): move_fs_to_cpu() - assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(opt_strength * opt_ddim_steps) + assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(req.prompt_strength * req.num_inference_steps) print(f"target t_enc is {t_enc} steps") - if opt_save_to_disk_path is not None: - session_out_path = os.path.join(opt_save_to_disk_path, req.session_id) + if req.save_to_disk_path is not None: + session_out_path = os.path.join(req.save_to_disk_path, req.session_id) os.makedirs(session_out_path, exist_ok=True) else: session_out_path = None @@ -366,7 +383,7 @@ def do_mk_img(req: Request): with precision_scope("cuda"): modelCS.to(device) uc = None - if opt_scale != 1.0: + if req.guidance_scale != 1.0: uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -393,7 +410,7 @@ def do_mk_img(req: Request): partial_x_samples = x_samples if req.stream_progress_updates: - n_steps = opt_ddim_steps if req.init_image is None else t_enc + n_steps = req.num_inference_steps if req.init_image is None else t_enc progress = {"step": i, "total_steps": n_steps} if req.stream_image_progress and i % 5 == 0: @@ -425,9 +442,9 @@ def do_mk_img(req: Request): # run the handler try: if handler == _txt2img: - x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name) + x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) else: - x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask) + x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask) yield from x_samples @@ -447,69 +464,49 @@ def do_mk_img(req: Request): x_sample = x_sample.astype(np.uint8) img = Image.fromarray(x_sample) - has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \ - (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')) + has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \ + (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN')) - return_orig_img = not has_filters or not opt_show_only_filtered + return_orig_img = not has_filters or not req.show_only_filtered_image if stop_processing: return_orig_img = True - if opt_save_to_disk_path is not None: - prompt_flattened = filename_regex.sub('_', prompts[0]) - prompt_flattened = prompt_flattened[:50] - - img_id = str(uuid.uuid4())[-8:] - - file_path = f"{prompt_flattened}_{img_id}" - img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}") - meta_out_path = os.path.join(session_out_path, f"{file_path}.txt") - + if req.save_to_disk_path is not None: if return_orig_img: + img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format) save_image(img, img_out_path) - - save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file) + meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], 'txt') + save_metadata(meta_out_path, req, prompts[0], opt_seed) if return_orig_img: - img_data = img_to_base64_str(img, opt_format) + img_data = img_to_base64_str(img, req.output_format) res_image_orig = ResponseImage(data=img_data, seed=opt_seed) res.images.append(res_image_orig) - if opt_save_to_disk_path is not None: + if req.save_to_disk_path is not None: res_image_orig.path_abs = img_out_path del img if has_filters and not stop_processing: - print('Applying filters..') - - gc() filters_applied = [] - - if opt_use_face_correction: - _, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - x_sample = output[:,:,::-1] - filters_applied.append(opt_use_face_correction) - - if opt_use_upscale: - output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1]) - x_sample = output[:,:,::-1] - filters_applied.append(opt_use_upscale) - - filtered_image = Image.fromarray(x_sample) - - filtered_img_data = img_to_base64_str(filtered_image, opt_format) - res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed) - res.images.append(res_image_filtered) - - filters_applied = "_".join(filters_applied) - - if opt_save_to_disk_path is not None: - filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}") - save_image(filtered_image, filtered_img_out_path) - res_image_filtered.path_abs = filtered_img_out_path - - del filtered_image + if req.use_face_correction: + x_sample = apply_filters('gfpgan', x_sample) + filters_applied.append(req.use_face_correction) + if req.use_upscale: + x_sample = apply_filters('real_esrgan', x_sample) + filters_applied.append(req.use_upscale) + if (len(filters_applied) > 0): + filtered_image = Image.fromarray(x_sample) + filtered_img_data = img_to_base64_str(filtered_image, req.output_format) + response_image = ResponseImage(data=filtered_img_data, seed=req.seed) + res.images.append(response_image) + if req.save_to_disk_path is not None: + filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format, "_".join(filters_applied)) + save_image(filtered_image, filtered_img_out_path) + response_image.path_abs = filtered_img_out_path + del filtered_image seeds += str(opt_seed) + "," opt_seed += 1 @@ -529,9 +526,20 @@ def save_image(img, img_out_path): except: print('could not save the file', traceback.format_exc()) -def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file): - metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}" - +def save_metadata(meta_out_path, req, prompt, opt_seed): + metadata = f"""{prompt} +Width: {req.width} +Height: {req.height} +Seed: {opt_seed} +Steps: {req.num_inference_steps} +Guidance Scale: {req.guidance_scale} +Prompt Strength: {req.prompt_strength} +Use Face Correction: {req.use_face_correction} +Use Upscaling: {req.use_upscale} +Sampler: {req.sampler} +Negative Prompt: {req.negative_prompt} +Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'} +""" try: with open(meta_out_path, 'w') as f: f.write(metadata) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py new file mode 100644 index 00000000..00364e0c --- /dev/null +++ b/ui/sd_internal/task_manager.py @@ -0,0 +1,298 @@ +import json +import traceback + +TASK_TTL = 15 * 60 # Discard last session's task timeout + +import queue, threading, time +from typing import Any, Generator, Hashable, Optional, Union + +from pydantic import BaseModel +from sd_internal import Request, Response + +class SymbolClass(type): # Print nicely formatted Symbol names. + def __repr__(self): return self.__qualname__ + def __str__(self): return self.__name__ +class Symbol(metaclass=SymbolClass): pass + +class ServerStates: + class Init(Symbol): pass + class LoadingModel(Symbol): pass + class Online(Symbol): pass + class Rendering(Symbol): pass + class Unavailable(Symbol): pass + +class RenderTask(): # Task with output queue and completion lock. + def __init__(self, req: Request): + self.request: Request = req # Initial Request + self.response: Any = None # Copy of the last reponse + self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) + self.error: Exception = None + self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + async def read_buffer_generator(self): + try: + while not self.buffer_queue.empty(): + res = self.buffer_queue.get(block=False) + self.buffer_queue.task_done() + yield res + except queue.Empty as e: yield + +# defaults from https://huggingface.co/blog/stable_diffusion +class ImageRequest(BaseModel): + session_id: str = "session" + prompt: str = "" + negative_prompt: str = "" + init_image: str = None # base64 + mask: str = None # base64 + num_outputs: int = 1 + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + width: int = 512 + height: int = 512 + seed: int = 42 + prompt_strength: float = 0.8 + sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + # allow_nsfw: bool = False + save_to_disk_path: str = None + turbo: bool = True + use_cpu: bool = False + use_full_precision: bool = False + use_face_correction: str = None # or "GFPGANv1.3" + use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" + use_stable_diffusion_model: str = "sd-v1-4" + show_only_filtered_image: bool = False + output_format: str = "jpeg" # or "png" + + stream_progress_updates: bool = False + stream_image_progress: bool = False + +# Temporary cache to allow to query tasks results for a short time after they are completed. +class TaskCache(): + def __init__(self): + self._base = dict() + self._lock: threading.Lock = threading.RLock() + def _get_ttl_time(self, ttl: int) -> int: + return int(time.time()) + ttl + def _is_expired(self, timestamp: int) -> bool: + return int(time.time()) >= timestamp + def clean(self) -> None: + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.') + try: + # Create a list of expired keys to delete + to_delete = [] + for key in self._base: + ttl, _ = self._base[key] + if self._is_expired(ttl): + to_delete.append(key) + # Remove Items + for key in to_delete: + del self._base[key] + print(f'Session {key} expired. Data removed.') + finally: + self._lock.release() + def clear(self) -> None: + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.') + try: self._base.clear() + finally: self._lock.release() + def delete(self, key: Hashable) -> bool: + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.') + try: + if key not in self._base: + return False + del self._base[key] + return True + finally: + self._lock.release() + def keep(self, key: Hashable, ttl: int) -> bool: + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.') + try: + if key in self._base: + _, value = self._base.get(key) + self._base[key] = (self._get_ttl_time(ttl), value) + return True + return False + finally: + self._lock.release() + def put(self, key: Hashable, value: Any, ttl: int) -> bool: + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.') + try: + self._base[key] = ( + self._get_ttl_time(ttl), value + ) + except Exception as e: + print(str(e)) + print(traceback.format_exc()) + return False + else: + return True + finally: + self._lock.release() + def tryGet(self, key: Hashable) -> Any: + if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.') + try: + ttl, value = self._base.get(key, (None, None)) + if ttl is not None and self._is_expired(ttl): + print(f'Session {key} expired. Discarding data.') + self.delete(key) + return None + return value + finally: + self._lock.release() + +current_state = ServerStates.Init +current_state_error:Exception = None +current_model_path = None +tasks_queue = queue.Queue() +task_cache = TaskCache() +default_model_to_load = None + +def preload_model(file_path=None): + global current_state, current_state_error, current_model_path + if file_path == None: + file_path = default_model_to_load + if file_path == current_model_path: + return + current_state = ServerStates.LoadingModel + try: + from . import runtime + runtime.load_model_ckpt(ckpt_to_use=file_path) + current_model_path = file_path + current_state_error = None + current_state = ServerStates.Online + except Exception as e: + current_model_path = None + current_state_error = e + current_state = ServerStates.Unavailable + print(traceback.format_exc()) + +def thread_render(): + global current_state, current_state_error, current_model_path + from . import runtime + current_state = ServerStates.Online + preload_model() + while True: + task_cache.clean() + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + task = None + try: + task = tasks_queue.get(timeout=1) + except queue.Empty as e: + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + else: continue + #if current_model_path != task.request.use_stable_diffusion_model: + # preload_model(task.request.use_stable_diffusion_model) + if current_state_error: + task.error = current_state_error + continue + print(f'Session {task.request.session_id} starting task {id(task)}') + try: + task.lock.acquire(blocking=False) + res = runtime.mk_img(task.request) + if current_model_path == task.request.use_stable_diffusion_model: + current_state = ServerStates.Rendering + else: + current_state = ServerStates.LoadingModel + except Exception as e: + task.error = e + task.lock.release() + tasks_queue.task_done() + print(traceback.format_exc()) + continue + dataQueue = None + if task.request.stream_progress_updates: + dataQueue = task.buffer_queue + for result in res: + if current_state == ServerStates.LoadingModel: + current_state = ServerStates.Rendering + current_model_path = task.request.use_stable_diffusion_model + if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): + runtime.stop_processing = True + if isinstance(current_state_error, StopAsyncIteration): + task.error = current_state_error + current_state_error = None + print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + if dataQueue: + dataQueue.put(result) + if isinstance(result, str): + result = json.loads(result) + task.response = result + if 'output' in result: + for out_obj in result['output']: + if 'path' in out_obj: + img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] + task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] + elif 'data' in out_obj: + task.temp_images[result['output'].index(out_obj)] = out_obj['data'] + task_cache.keep(task.request.session_id, TASK_TTL) + # Task completed + task.lock.release() + tasks_queue.task_done() + task_cache.keep(task.request.session_id, TASK_TTL) + if isinstance(task.error, StopAsyncIteration): + print(f'Session {task.request.session_id} task {id(task)} cancelled!') + elif task.error is not None: + print(f'Session {task.request.session_id} task {id(task)} failed!') + else: + print(f'Session {task.request.session_id} task {id(task)} completed.') + current_state = ServerStates.Online + +render_thread = threading.Thread(target=thread_render) + +def start_render_thread(): + # Start Rendering Thread + render_thread.daemon = True + render_thread.start() + +def shutdown_event(): # Signal render thread to close on shutdown + global current_state_error + current_state_error = SystemExit('Application shutting down.') + +def render(req : ImageRequest): + if not render_thread.is_alive(): # Render thread is dead + raise ChildProcessError('Rendering thread has died.') + # Alive, check if task in cache + task = task_cache.tryGet(req.session_id) + if task and not task.response and not task.error and not task.lock.locked(): + # Unstarted task pending, deny queueing more than one. + raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') + # + from . import runtime + r = Request() + r.session_id = req.session_id + r.prompt = req.prompt + r.negative_prompt = req.negative_prompt + r.init_image = req.init_image + r.mask = req.mask + r.num_outputs = req.num_outputs + r.num_inference_steps = req.num_inference_steps + r.guidance_scale = req.guidance_scale + r.width = req.width + r.height = req.height + r.seed = req.seed + r.prompt_strength = req.prompt_strength + r.sampler = req.sampler + # r.allow_nsfw = req.allow_nsfw + r.turbo = req.turbo + r.use_cpu = req.use_cpu + r.use_full_precision = req.use_full_precision + r.save_to_disk_path = req.save_to_disk_path + r.use_upscale: str = req.use_upscale + r.use_face_correction = req.use_face_correction + r.show_only_filtered_image = req.show_only_filtered_image + r.output_format = req.output_format + + r.stream_progress_updates = True # the underlying implementation only supports streaming + r.stream_image_progress = req.stream_image_progress + + if not req.stream_progress_updates: + r.stream_image_progress = False + + new_task = RenderTask(r) + if task_cache.put(r.session_id, new_task, TASK_TTL): + tasks_queue.put(new_task, block=True, timeout=30) + return new_task + raise RuntimeError('Failed to add task to cache.') diff --git a/ui/server.py b/ui/server.py index 5d7f1bfd..a803ceb9 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,90 +14,32 @@ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +TASK_TTL = 15 * 60 # Discard last session's task timeout from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles -from starlette.responses import FileResponse, StreamingResponse +from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel import logging +import queue, threading, time +from typing import Any, Generator, Hashable, Optional, Union -from sd_internal import Request, Response +from sd_internal import Request, Response, task_manager app = FastAPI() -model_loaded = False -model_is_loading = False - modifiers_cache = None outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) # don't show access log entries for URLs that start with the given prefix -ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails'] +ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] +NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") -# defaults from https://huggingface.co/blog/stable_diffusion -class ImageRequest(BaseModel): - session_id: str = "session" - prompt: str = "" - negative_prompt: str = "" - init_image: str = None # base64 - mask: str = None # base64 - num_outputs: int = 1 - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = 42 - prompt_strength: float = 0.8 - sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" - # allow_nsfw: bool = False - save_to_disk_path: str = None - turbo: bool = True - use_cpu: bool = False - use_full_precision: bool = False - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" - use_stable_diffusion_model: str = "sd-v1-4" - show_only_filtered_image: bool = False - output_format: str = "jpeg" # or "png" - - stream_progress_updates: bool = False - stream_image_progress: bool = False - class SetAppConfigRequest(BaseModel): update_branch: str = "main" -@app.get('/') -def read_root(): - headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=headers) - -@app.get('/ping') -async def ping(): - global model_loaded, model_is_loading - - try: - if model_loaded: - return {'OK'} - - if model_is_loading: - return {'ERROR'} - - model_is_loading = True - - from sd_internal import runtime - - runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load()) - - model_loaded = True - model_is_loading = False - - return {'OK'} - except Exception as e: - print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) - # needs to support the legacy installations def get_initial_model_to_load(): custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') @@ -114,7 +56,6 @@ def get_initial_model_to_load(): ckpt_to_use = model_path else: print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt') - return ckpt_to_use def resolve_model_to_use(model_name): @@ -126,92 +67,110 @@ def resolve_model_to_use(model_name): model_path = legacy_model_path else: model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) - return model_path +@app.on_event("shutdown") +def shutdown_event(): # Signal render thread to close on shutdown + task_manager.current_state_error = SystemExit('Application shutting down.') + +@app.get('/') +def read_root(): + return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) + +@app.get('/ping') # Get server and optionally session status. +def ping(session_id:str=None): + if not task_manager.render_thread.is_alive(): # Render thread is dead. + if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) + return HTTPException(status_code=500, detail='Render thread is dead.') + if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) + # Alive + response = {'status': str(task_manager.current_state)} + if session_id: + task = task_manager.task_cache.tryGet(session_id) + if task: + response['task'] = id(task) + if task.lock.locked(): + response['session'] = 'running' + elif isinstance(task.error, StopAsyncIteration): + response['session'] = 'stopped' + elif task.error: + response['session'] = 'error' + elif not task.buffer_queue.empty(): + response['session'] = 'buffer' + elif task.response: + response['session'] = 'completed' + else: + response['session'] = 'pending' + return JSONResponse(response, headers=NOCACHE_HEADERS) + def save_model_to_config(model_name): config = getConfig() if 'model' not in config: config['model'] = {} config['model']['stable-diffusion'] = model_name - setConfig(config) -@app.post('/image') -def image(req : ImageRequest): - from sd_internal import runtime - - r = Request() - r.session_id = req.session_id - r.prompt = req.prompt - r.negative_prompt = req.negative_prompt - r.init_image = req.init_image - r.mask = req.mask - r.num_outputs = req.num_outputs - r.num_inference_steps = req.num_inference_steps - r.guidance_scale = req.guidance_scale - r.width = req.width - r.height = req.height - r.seed = req.seed - r.prompt_strength = req.prompt_strength - r.sampler = req.sampler - # r.allow_nsfw = req.allow_nsfw - r.turbo = req.turbo - r.use_cpu = req.use_cpu - r.use_full_precision = req.use_full_precision - r.save_to_disk_path = req.save_to_disk_path - r.use_upscale: str = req.use_upscale - r.use_face_correction = req.use_face_correction - r.show_only_filtered_image = req.show_only_filtered_image - r.output_format = req.output_format - - r.stream_progress_updates = True # the underlying implementation only supports streaming - r.stream_image_progress = req.stream_image_progress - - r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model) - - save_model_to_config(req.use_stable_diffusion_model) - +@app.post('/render') +def render(req : task_manager.ImageRequest): try: - if not req.stream_progress_updates: - r.stream_image_progress = False - - res = runtime.mk_img(r) - - if req.stream_progress_updates: - return StreamingResponse(res, media_type='application/json') - else: # compatibility mode: buffer the streaming responses, and return the last one - last_result = None - - for result in res: - last_result = result - - return json.loads(last_result) + save_model_to_config(req.use_stable_diffusion_model) + req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model) + new_task = task_manager.render(req) + response = { + 'status': str(task_manager.current_state), + 'queue': task_manager.tasks_queue.qsize(), + 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', + 'task': id(new_task) + } + return JSONResponse(response, headers=NOCACHE_HEADERS) + except ChildProcessError as e: # Render thread is dead + return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error + except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one. + return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable except Exception as e: - print(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e)) +@app.get('/image/stream/{session_id:str}/{task_id:int}') +def stream(session_id:str, task_id:int): + #TODO Move to WebSockets ?? + task = task_manager.task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone + if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict + if task.buffer_queue.empty() and not task.lock.locked(): + if task.response: + #print(f'Session {session_id} sending cached response') + return JSONResponse(task.response, headers=NOCACHE_HEADERS) + return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early + #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + return StreamingResponse(task.read_buffer_generator(), media_type='application/json') + @app.get('/image/stop') -def stop(): - try: - if model_is_loading: - return {'ERROR'} - - from sd_internal import runtime - runtime.stop_processing = True - +def stop(session_id:str=None): + if not session_id: + if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: + return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict + task_manager.current_state_error = StopAsyncIteration('') return {'OK'} - except Exception as e: - print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) + task = task_manager.task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found + if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict + task.error = StopAsyncIteration('') + return {'OK'} -@app.get('/image/tmp/{session_id}/{img_id}') +@app.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): - from sd_internal import runtime - buf = runtime.temp_images[session_id + '/' + img_id] - buf.seek(0) - return StreamingResponse(buf, media_type='image/jpeg') + task = task_manager.task_cache.tryGet(session_id) + if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone + if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early + try: + img_data = task.temp_images[img_id] + if isinstance(img_data, str): + return img_data + img_data.seek(0) + return StreamingResponse(img_data, media_type='image/jpeg') + except KeyError as e: + return HTTPException(status_code=500, detail=str(e)) @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): @@ -242,42 +201,27 @@ async def setAppConfig(req : SetAppConfigRequest): print(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e)) -@app.get('/app_config') -def getAppConfig(): +def getConfig(default_val={}): try: config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return HTTPException(status_code=500, detail="No config file") - + return default_val with open(config_json_path, 'r') as f: return json.load(f) except Exception as e: + print(str(e)) print(traceback.format_exc()) - return HTTPException(status_code=500, detail=str(e)) - -def getConfig(): - try: - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - - if not os.path.exists(config_json_path): - return {} - - with open(config_json_path, 'r') as f: - return json.load(f) - except Exception as e: - return {} + return default_val def setConfig(config): try: config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w') as f: return json.dump(config, f) except: + print(str(e)) print(traceback.format_exc()) -@app.get('/models') def getModels(): models = { 'active': { @@ -307,14 +251,21 @@ def getModels(): return models -@app.get('/modifiers.json') -def read_modifiers(): - headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=headers) - -@app.get('/output_dir') -def read_home_dir(): - return {outpath} +@app.get('/get/{key:path}') +def read_web_data(key:str=None): + if not key: # /get without parameters, stable-diffusion easter egg. + return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot + elif key == 'app_config': + config = getConfig(default_val=None) + if config is None: + return HTTPException(status_code=500, detail="Config file is missing or unreadable") + return JSONResponse(config, headers=NOCACHE_HEADERS) + elif key == 'models': + return JSONResponse(getModels(), headers=NOCACHE_HEADERS) + elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) + elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS) + else: + return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found # don't log certain requests class LogSuppressFilter(logging.Filter): @@ -323,10 +274,11 @@ class LogSuppressFilter(logging.Filter): for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: if path.find(prefix) != -1: return False - return True - logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) +task_manager.default_model_to_load = get_initial_model_to_load() +task_manager.start_render_thread() + # start the browser ui import webbrowser; webbrowser.open('http://localhost:9000') \ No newline at end of file