diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 81baa103..e9131b4d 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -149,12 +149,10 @@ def device_init(device_selection=None): thread_data.device = 'cpu' def is_first_cuda_device(device): - if thread_data.device == 0 or thread_data.device == '0': - return True - if thread_data.device == 'cuda' or thread_data.device == 'cuda:0': - return True - if thread_data.device == torch.device(0): - return True + if device is None: return False + if device == 0 or device == '0': return True + if device == 'cuda' or device == 'cuda:0': return True + if device == torch.device(0): return True return False def load_model_ckpt(): diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 75e1d262..8e61b7ea 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -9,6 +9,7 @@ from typing import Any, Generator, Hashable, Optional, Union from pydantic import BaseModel from sd_internal import Request, Response +THREAD_NAME_PREFIX = 'Runtime-Render/' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. @@ -285,12 +286,15 @@ def thread_render(device): current_state = ServerStates.Online def is_alive(name=None): + from . import runtime # When calling runtime from here DO NOT USE thread specific attributes or functions. if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED) nbr_alive = 0 try: for rthread in render_threads: - if name and not rthread.name.endswith(name): - continue + thread_name = rthread.name[len(THREAD_NAME_PREFIX):] + if name and thread_name != name: + if not runtime.is_first_cuda_device(name) and not runtime.is_first_cuda_device(thread_name): + continue if rthread.is_alive(): nbr_alive += 1 return nbr_alive @@ -303,7 +307,7 @@ def start_render_thread(device='auto'): try: rthread = threading.Thread(target=thread_render, kwargs={'device': device}) rthread.daemon = True - rthread.name = 'Runner/' + device + rthread.name = THREAD_NAME_PREFIX + device rthread.start() render_threads.append(rthread) finally: