diff --git a/ui/server.py b/ui/server.py index 18d5214e..1988cab2 100644 --- a/ui/server.py +++ b/ui/server.py @@ -162,16 +162,16 @@ def preload_model(file_path=None): print(traceback.format_exc()) def thread_render(): - global current_state, current_state_error + global current_state, current_state_error, current_model_path from sd_internal import runtime current_state = ServerStates.Online preload_model() while True: task_cache.clean() - task = None if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable return + task = None try: task = tasks_queue.get(timeout=1) except queue.Empty as e: @@ -185,10 +185,13 @@ def thread_render(): task.error = current_state_error continue print(f'Session {task.request.session_id} starting task {id(task)}') - current_state = ServerStates.Rendering try: task.lock.acquire(blocking=False) res = runtime.mk_img(task.request) + if current_model_path == task.request.use_stable_diffusion_model: + current_state = ServerStates.Rendering + else: + current_state = ServerStates.LoadingModel except Exception as e: task.error = e task.lock.release() @@ -199,6 +202,9 @@ def thread_render(): if task.request.stream_progress_updates: dataQueue = task.buffer_queue for result in res: + if current_state == ServerStates.LoadingModel: + current_state = ServerStates.Rendering + current_model_path = task.request.use_stable_diffusion_model if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): runtime.stop_processing = True if isinstance(current_state_error, StopAsyncIteration):