From 7de699c7fa5e2782466e42a284666374d815b494 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 15 Oct 2022 03:28:20 -0400 Subject: [PATCH] Moved a lot of code into task_manager.py --- ui/sd_internal/task_manager.py | 288 ++++++++++++++++++++++++++++ ui/server.py | 332 +++++---------------------------- 2 files changed, 334 insertions(+), 286 deletions(-) create mode 100644 ui/sd_internal/task_manager.py diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py new file mode 100644 index 00000000..67b0bde0 --- /dev/null +++ b/ui/sd_internal/task_manager.py @@ -0,0 +1,288 @@ +import json +import traceback + +TASK_TTL = 15 * 60 # Discard last session's task timeout + +import queue, threading, time +from typing import Any, Generator, Hashable, Optional, Union + +from pydantic import BaseModel +from sd_internal import Request, Response + +class SymbolClass(type): # Print nicely formatted Symbol names. + def __repr__(self): return self.__qualname__ + def __str__(self): return self.__name__ +class Symbol(metaclass=SymbolClass): pass + +class ServerStates: + class Init(Symbol): pass + class LoadingModel(Symbol): pass + class Online(Symbol): pass + class Rendering(Symbol): pass + class Unavailable(Symbol): pass + +class RenderTask(): # Task with output queue and completion lock. + def __init__(self, req: Request): + self.request: Request = req # Initial Request + self.response: Any = None # Copy of the last reponse + self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) + self.error: Exception = None + self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + async def read_buffer_generator(self): + try: + while not self.buffer_queue.empty(): + res = self.buffer_queue.get(block=False) + self.buffer_queue.task_done() + yield res + except queue.Empty as e: yield + +# defaults from https://huggingface.co/blog/stable_diffusion +class ImageRequest(BaseModel): + session_id: str = "session" + prompt: str = "" + negative_prompt: str = "" + init_image: str = None # base64 + mask: str = None # base64 + num_outputs: int = 1 + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + width: int = 512 + height: int = 512 + seed: int = 42 + prompt_strength: float = 0.8 + sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + # allow_nsfw: bool = False + save_to_disk_path: str = None + turbo: bool = True + use_cpu: bool = False + use_full_precision: bool = False + use_face_correction: str = None # or "GFPGANv1.3" + use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" + use_stable_diffusion_model: str = "sd-v1-4" + show_only_filtered_image: bool = False + output_format: str = "jpeg" # or "png" + + stream_progress_updates: bool = False + stream_image_progress: bool = False + +# Temporary cache to allow to query tasks results for a short time after they are completed. +class TaskCache(): + def __init__(self): + self._base = dict() + self._lock: threading.Lock = threading.Lock() + def _get_ttl_time(self, ttl: int) -> int: + return int(time.time()) + ttl + def _is_expired(self, timestamp: int) -> bool: + return int(time.time()) >= timestamp + def clean(self) -> None: + self._lock.acquire() + try: + for key in self._base: + ttl, _ = self._base[key] + if self._is_expired(ttl): + del self._base[key] + finally: + self._lock.release() + def clear(self) -> None: + self._lock.acquire() + try: self._base.clear() + finally: self._lock.release() + def delete(self, key: Hashable) -> bool: + self._lock.acquire() + try: + if key not in self._base: + return False + del self._base[key] + return True + finally: + self._lock.release() + def keep(self, key: Hashable, ttl: int) -> bool: + self._lock.acquire() + try: + if key in self._base: + _, value = self._base.get(key) + self._base[key] = (self._get_ttl_time(ttl), value) + return True + return False + finally: + self._lock.release() + def put(self, key: Hashable, value: Any, ttl: int) -> bool: + self._lock.acquire() + try: + self._base[key] = ( + self._get_ttl_time(ttl), value + ) + except Exception: + return False + else: + return True + finally: + self._lock.release() + def tryGet(self, key: Hashable) -> Any: + self._lock.acquire() + try: + ttl, value = self._base.get(key, (None, None)) + if ttl is not None and self._is_expired(ttl): + self.delete(key) + return None + return value + finally: + self._lock.release() + +current_state = ServerStates.Init +current_state_error:Exception = None +current_model_path = None +tasks_queue = queue.Queue() +task_cache = TaskCache() +default_model_to_load = None + +def preload_model(file_path=None): + global current_state, current_state_error, current_model_path + if file_path == None: + file_path = default_model_to_load + if file_path == current_model_path: + return + current_state = ServerStates.LoadingModel + try: + from . import runtime + runtime.load_model_ckpt(ckpt_to_use=file_path) + current_model_path = file_path + current_state_error = None + current_state = ServerStates.Online + except Exception as e: + current_model_path = None + current_state_error = e + current_state = ServerStates.Unavailable + print(traceback.format_exc()) + +def thread_render(): + global current_state, current_state_error, current_model_path + from . import runtime + current_state = ServerStates.Online + preload_model() + while True: + task_cache.clean() + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + task = None + try: + task = tasks_queue.get(timeout=1) + except queue.Empty as e: + if isinstance(current_state_error, SystemExit): + current_state = ServerStates.Unavailable + return + else: continue + #if current_model_path != task.request.use_stable_diffusion_model: + # preload_model(task.request.use_stable_diffusion_model) + if current_state_error: + task.error = current_state_error + continue + print(f'Session {task.request.session_id} starting task {id(task)}') + try: + task.lock.acquire(blocking=False) + res = runtime.mk_img(task.request) + if current_model_path == task.request.use_stable_diffusion_model: + current_state = ServerStates.Rendering + else: + current_state = ServerStates.LoadingModel + except Exception as e: + task.error = e + task.lock.release() + tasks_queue.task_done() + print(traceback.format_exc()) + continue + dataQueue = None + if task.request.stream_progress_updates: + dataQueue = task.buffer_queue + for result in res: + if current_state == ServerStates.LoadingModel: + current_state = ServerStates.Rendering + current_model_path = task.request.use_stable_diffusion_model + if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): + runtime.stop_processing = True + if isinstance(current_state_error, StopAsyncIteration): + task.error = current_state_error + current_state_error = None + print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + if dataQueue: + dataQueue.put(result) + if isinstance(result, str): + result = json.loads(result) + task.response = result + if 'output' in result: + for out_obj in result['output']: + if 'path' in out_obj: + img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] + task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] + elif 'data' in out_obj: + task.temp_images[result['output'].index(out_obj)] = out_obj['data'] + task_cache.keep(task.request.session_id, TASK_TTL) + # Task completed + task.lock.release() + tasks_queue.task_done() + task_cache.keep(task.request.session_id, TASK_TTL) + if isinstance(task.error, StopAsyncIteration): + print(f'Session {task.request.session_id} task {id(task)} cancelled!') + elif task.error is not None: + print(f'Session {task.request.session_id} task {id(task)} failed!') + else: + print(f'Session {task.request.session_id} task {id(task)} completed.') + current_state = ServerStates.Online + +render_thread = threading.Thread(target=thread_render) + +def start_render_thread(): + # Start Rendering Thread + render_thread.daemon = True + render_thread.start() + +def shutdown_event(): # Signal render thread to close on shutdown + global current_state_error + current_state_error = SystemExit('Application shutting down.') + +def render(req : ImageRequest): + if not render_thread.is_alive(): # Render thread is dead + raise ChildProcessError('Rendering thread has died.') + # Alive, check if task in cache + task = task_cache.tryGet(req.session_id) + if task and not task.response and not task.error and not task.lock.locked(): + # Unstarted task pending, deny queueing more than one. + raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') + # + from . import runtime + r = Request() + r.session_id = req.session_id + r.prompt = req.prompt + r.negative_prompt = req.negative_prompt + r.init_image = req.init_image + r.mask = req.mask + r.num_outputs = req.num_outputs + r.num_inference_steps = req.num_inference_steps + r.guidance_scale = req.guidance_scale + r.width = req.width + r.height = req.height + r.seed = req.seed + r.prompt_strength = req.prompt_strength + r.sampler = req.sampler + # r.allow_nsfw = req.allow_nsfw + r.turbo = req.turbo + r.use_cpu = req.use_cpu + r.use_full_precision = req.use_full_precision + r.save_to_disk_path = req.save_to_disk_path + r.use_upscale: str = req.use_upscale + r.use_face_correction = req.use_face_correction + r.show_only_filtered_image = req.show_only_filtered_image + r.output_format = req.output_format + + r.stream_progress_updates = True # the underlying implementation only supports streaming + r.stream_image_progress = req.stream_image_progress + + if not req.stream_progress_updates: + r.stream_image_progress = False + + new_task = RenderTask(r) + task_cache.put(r.session_id, new_task, TASK_TTL) + tasks_queue.put(new_task) + return new_task diff --git a/ui/server.py b/ui/server.py index 7522f234..478b9404 100644 --- a/ui/server.py +++ b/ui/server.py @@ -24,7 +24,7 @@ import logging import queue, threading, time from typing import Any, Generator, Hashable, Optional, Union -from sd_internal import Request, Response +from sd_internal import Request, Response, task_manager app = FastAPI() @@ -37,214 +37,9 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") -class SymbolClass(type): # Print nicely formatted Symbol names. - def __repr__(self): return self.__qualname__ - def __str__(self): return self.__name__ -class Symbol(metaclass=SymbolClass): pass - -class ServerStates: - class Init(Symbol): pass - class LoadingModel(Symbol): pass - class Online(Symbol): pass - class Rendering(Symbol): pass - class Unavailable(Symbol): pass - -class RenderTask(): # Task with output queue and completion lock. - def __init__(self, req: Request): - self.request: Request = req # Initial Request - self.response: Any = None # Copy of the last reponse - self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) - self.error: Exception = None - self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed - self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments - -current_state = ServerStates.Init -current_state_error:Exception = None -current_model_path = None -tasks_queue = queue.Queue() - -# defaults from https://huggingface.co/blog/stable_diffusion -class ImageRequest(BaseModel): - session_id: str = "session" - prompt: str = "" - negative_prompt: str = "" - init_image: str = None # base64 - mask: str = None # base64 - num_outputs: int = 1 - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = 42 - prompt_strength: float = 0.8 - sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" - # allow_nsfw: bool = False - save_to_disk_path: str = None - turbo: bool = True - use_cpu: bool = False - use_full_precision: bool = False - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" - use_stable_diffusion_model: str = "sd-v1-4" - show_only_filtered_image: bool = False - output_format: str = "jpeg" # or "png" - - stream_progress_updates: bool = False - stream_image_progress: bool = False - -# Temporary cache to allow to query tasks results for a short time after they are completed. -class TaskCache(): - def __init__(self): - self._base = dict() - def _get_ttl_time(self, ttl: int) -> int: - return int(time.time()) + ttl - def _is_expired(self, timestamp: int) -> bool: - return int(time.time()) >= timestamp - def clean(self) -> None: - for key in self._base: - ttl, _ = self._base[key] - if self._is_expired(ttl): - del self._base[key] - def clear(self) -> None: - self._base.clear() - def delete(self, key: Hashable) -> bool: - if key not in self._base: - return False - del self._base[key] - return True - def keep(self, key: Hashable, ttl: int) -> bool: - if key in self._base: - _, value = self._base.get(key) - self._base[key] = (self._get_ttl_time(ttl), value) - return True - return False - def put(self, key: Hashable, value: Any, ttl: int) -> bool: - try: - self._base[key] = ( - self._get_ttl_time(ttl), value - ) - except Exception: - return False - return True - def tryGet(self, key: Hashable) -> Any: - ttl, value = self._base.get(key, (None, None)) - if ttl is not None and self._is_expired(ttl): - self.delete(key) - return None - return value - -task_cache = TaskCache() - class SetAppConfigRequest(BaseModel): update_branch: str = "main" -@app.get('/') -def read_root(): - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) - -def preload_model(file_path=None): - global current_state, current_state_error, current_model_path - if file_path == None: - file_path = get_initial_model_to_load() - if file_path == current_model_path: - return - current_state = ServerStates.LoadingModel - try: - from sd_internal import runtime - runtime.load_model_ckpt(ckpt_to_use=file_path) - current_model_path = file_path - current_state_error = None - current_state = ServerStates.Online - except Exception as e: - current_model_path = None - current_state_error = e - current_state = ServerStates.Unavailable - print(traceback.format_exc()) - -def thread_render(): - global current_state, current_state_error, current_model_path - from sd_internal import runtime - current_state = ServerStates.Online - preload_model() - while True: - task_cache.clean() - if isinstance(current_state_error, SystemExit): - current_state = ServerStates.Unavailable - return - task = None - try: - task = tasks_queue.get(timeout=1) - except queue.Empty as e: - if isinstance(current_state_error, SystemExit): - current_state = ServerStates.Unavailable - return - else: continue - #if current_model_path != task.request.use_stable_diffusion_model: - # preload_model(task.request.use_stable_diffusion_model) - if current_state_error: - task.error = current_state_error - continue - print(f'Session {task.request.session_id} starting task {id(task)}') - try: - task.lock.acquire(blocking=False) - res = runtime.mk_img(task.request) - if current_model_path == task.request.use_stable_diffusion_model: - current_state = ServerStates.Rendering - else: - current_state = ServerStates.LoadingModel - except Exception as e: - task.error = e - task.lock.release() - tasks_queue.task_done() - print(traceback.format_exc()) - continue - dataQueue = None - if task.request.stream_progress_updates: - dataQueue = task.buffer_queue - for result in res: - if current_state == ServerStates.LoadingModel: - current_state = ServerStates.Rendering - current_model_path = task.request.use_stable_diffusion_model - if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): - runtime.stop_processing = True - if isinstance(current_state_error, StopAsyncIteration): - task.error = current_state_error - current_state_error = None - print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') - if dataQueue: - dataQueue.put(result) - if isinstance(result, str): - result = json.loads(result) - task.response = result - if 'output' in result: - for out_obj in result['output']: - if 'path' in out_obj: - img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] - task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] - elif 'data' in out_obj: - task.temp_images[result['output'].index(out_obj)] = out_obj['data'] - task_cache.keep(task.request.session_id, TASK_TTL) - # Task completed - task.lock.release() - tasks_queue.task_done() - task_cache.keep(task.request.session_id, TASK_TTL) - if isinstance(task.error, StopAsyncIteration): - print(f'Session {task.request.session_id} task {id(task)} cancelled!') - elif task.error is not None: - print(f'Session {task.request.session_id} task {id(task)} failed!') - else: - print(f'Session {task.request.session_id} task {id(task)} completed.') - current_state = ServerStates.Online -# Start Rendering Thread -render_thread = threading.Thread(target=thread_render) -render_thread.daemon = True -render_thread.start() - -@app.on_event("shutdown") -def shutdown_event(): # Signal render thread to close on shutdown - global current_state_error - current_state_error = SystemExit('Application shutting down.') - # needs to support the legacy installations def get_initial_model_to_load(): custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') @@ -261,7 +56,6 @@ def get_initial_model_to_load(): ckpt_to_use = model_path else: print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt') - return ckpt_to_use def resolve_model_to_use(model_name): @@ -275,24 +69,24 @@ def resolve_model_to_use(model_name): model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) return model_path -def save_model_to_config(model_name): - config = getConfig() - if 'model' not in config: - config['model'] = {} +@app.on_event("shutdown") +def shutdown_event(): # Signal render thread to close on shutdown + task_manager.current_state_error = SystemExit('Application shutting down.') - config['model']['stable-diffusion'] = model_name - setConfig(config) +@app.get('/') +def read_root(): + return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) @app.get('/ping') # Get server and optionally session status. def ping(session_id:str=None): - if not render_thread.is_alive(): # Render thread is dead. - if current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) + if not task_manager.render_thread.is_alive(): # Render thread is dead. + if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) return HTTPException(status_code=500, detail='Render thread is dead.') - if current_state_error and not isinstance(current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) + if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) # Alive - response = {'status': str(current_state)} + response = {'status': str(task_manager.current_state)} if session_id: - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if task: response['task'] = id(task) if task.lock.locked(): @@ -309,74 +103,38 @@ def ping(session_id:str=None): response['session'] = 'pending' return JSONResponse(response, headers=NOCACHE_HEADERS) +def save_model_to_config(model_name): + config = getConfig() + if 'model' not in config: + config['model'] = {} + + config['model']['stable-diffusion'] = model_name + setConfig(config) + @app.post('/render') -def render(req : ImageRequest): - if not render_thread.is_alive(): # Render thread is dead - return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error - # Alive, check if task in cache - task = task_cache.tryGet(req.session_id) - if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one. - return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable - # - from sd_internal import runtime - r = Request() - r.session_id = req.session_id - r.prompt = req.prompt - r.negative_prompt = req.negative_prompt - r.init_image = req.init_image - r.mask = req.mask - r.num_outputs = req.num_outputs - r.num_inference_steps = req.num_inference_steps - r.guidance_scale = req.guidance_scale - r.width = req.width - r.height = req.height - r.seed = req.seed - r.prompt_strength = req.prompt_strength - r.sampler = req.sampler - # r.allow_nsfw = req.allow_nsfw - r.turbo = req.turbo - r.use_cpu = req.use_cpu - r.use_full_precision = req.use_full_precision - r.save_to_disk_path = req.save_to_disk_path - r.use_upscale: str = req.use_upscale - r.use_face_correction = req.use_face_correction - r.show_only_filtered_image = req.show_only_filtered_image - r.output_format = req.output_format - - r.stream_progress_updates = True # the underlying implementation only supports streaming - r.stream_image_progress = req.stream_image_progress - - r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model) - - save_model_to_config(req.use_stable_diffusion_model) - - if not req.stream_progress_updates: - r.stream_image_progress = False - - new_task = RenderTask(r) - task_cache.put(r.session_id, new_task, TASK_TTL) - tasks_queue.put(new_task) - - response = { - 'status': str(current_state), - 'queue': tasks_queue.qsize(), - 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', - 'task': id(new_task) - } - return JSONResponse(response, headers=NOCACHE_HEADERS) - -async def read_data_generator(data:queue.Queue, lock:threading.Lock): +def render(req : task_manager.ImageRequest): try: - while not data.empty(): - res = data.get(block=False) - data.task_done() - yield res - except queue.Empty as e: yield + save_model_to_config(req.use_stable_diffusion_model) + req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model) + new_task = task_manager.render(req) + response = { + 'status': str(task_manager.current_state), + 'queue': task_manager.tasks_queue.qsize(), + 'stream': f'/image/stream/{req.session_id}/{id(new_task)}', + 'task': id(new_task) + } + return JSONResponse(response, headers=NOCACHE_HEADERS) + except ChildProcessError as e: # Render thread is dead + return HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error + except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one. + return HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable + except Exception as e: + return HTTPException(status_code=500, detail=str(e)) @app.get('/image/stream/{session_id:str}/{task_id:int}') def stream(session_id:str, task_id:int): #TODO Move to WebSockets ?? - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if not task: return HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone if (id(task) != task_id): return HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): @@ -385,17 +143,16 @@ def stream(session_id:str, task_id:int): return JSONResponse(task.response, headers=NOCACHE_HEADERS) return HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') - return StreamingResponse(read_data_generator(task.buffer_queue, task.lock), media_type='application/json') + return StreamingResponse(task.read_buffer_generator(), media_type='application/json') @app.get('/image/stop') def stop(session_id:str=None): if not session_id: - if current_state == ServerStates.Online or current_state == ServerStates.Unavailable: + if task_manager.current_state == ServerStates.Online or task_manager.current_state == ServerStates.Unavailable: return HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict - global current_state_error - current_state_error = StopAsyncIteration() + task_manager.current_state_error = StopAsyncIteration('') return {'OK'} - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if not task: return HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found if isinstance(task.error, StopAsyncIteration): return HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict task.error = StopAsyncIteration('') @@ -403,7 +160,7 @@ def stop(session_id:str=None): @app.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): - task = task_cache.tryGet(session_id) + task = task_manager.task_cache.tryGet(session_id) if not task: return HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone if not task.temp_images[img_id]: return HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early try: @@ -520,5 +277,8 @@ class LogSuppressFilter(logging.Filter): return True logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) +task_manager.default_model_to_load = get_initial_model_to_load() +task_manager.start_render_thread() + # start the browser ui import webbrowser; webbrowser.open('http://localhost:9000') \ No newline at end of file