From 364e364429ad1ca4e082f7e23c31b06876a02715 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Sat, 22 Oct 2022 13:52:13 -0400 Subject: [PATCH] Added get_cached_task to replace task_cache.tryGet in server.py Now updated cache TTL on /stream and temp images endpoints. Keep images alive longer when browser keeps reading the endpoints. --- ui/sd_internal/task_manager.py | 7 +++++++ ui/server.py | 8 ++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 1216c2c0..0d26327e 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -326,6 +326,13 @@ def thread_render(device): print(f'Session {task.request.session_id} task {id(task)} completed.') current_state = ServerStates.Online +def get_cached_task(session_id:str, update_ttl:bool=False): + # By calling keep before tryGet, wont discard if was expired. + if update_ttl and not task_cache.keep(session_id, TASK_TTL): + # Failed to keep task, already gone. + return None + return task_cache.tryGet(session_id) + def is_first_cuda_device(device): from . import runtime # When calling runtime from outside thread_render DO NOT USE thread specific attributes or functions. return runtime.is_first_cuda_device(device) diff --git a/ui/server.py b/ui/server.py index a91d87ce..dc76a3bf 100644 --- a/ui/server.py +++ b/ui/server.py @@ -251,7 +251,7 @@ def ping(session_id:str=None): # Alive response = {'status': str(task_manager.current_state)} if session_id: - task = task_manager.task_cache.tryGet(session_id) + task = task_manager.get_cached_task(session_id) if task: response['task'] = id(task) if task.lock.locked(): @@ -302,7 +302,7 @@ def render(req : task_manager.ImageRequest): @app.get('/image/stream/{session_id:str}/{task_id:int}') def stream(session_id:str, task_id:int): #TODO Move to WebSockets ?? - task = task_manager.task_cache.tryGet(session_id) + task = task_manager.get_cached_task(session_id, update_ttl=True) if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): @@ -320,7 +320,7 @@ def stop(session_id:str=None): raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict task_manager.current_state_error = StopAsyncIteration('') return {'OK'} - task = task_manager.task_cache.tryGet(session_id) + task = task_manager.get_cached_task(session_id) if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict task.error = StopAsyncIteration('') @@ -328,7 +328,7 @@ def stop(session_id:str=None): @app.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): - task = task_manager.task_cache.tryGet(session_id) + task = task_manager.get_cached_task(session_id, update_ttl=True) if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early try: