Fixed bug in task_manager.is_alive and added way to check for first device.

This commit is contained in:
Marc-Andre Ferland 2022-10-16 23:06:41 -04:00
parent 994d62ac65
commit 41bfb96b6b
2 changed files with 11 additions and 9 deletions

View File

@ -149,12 +149,10 @@ def device_init(device_selection=None):
thread_data.device = 'cpu'
def is_first_cuda_device(device):
if thread_data.device == 0 or thread_data.device == '0':
return True
if thread_data.device == 'cuda' or thread_data.device == 'cuda:0':
return True
if thread_data.device == torch.device(0):
return True
if device is None: return False
if device == 0 or device == '0': return True
if device == 'cuda' or device == 'cuda:0': return True
if device == torch.device(0): return True
return False
def load_model_ckpt():

View File

@ -9,6 +9,7 @@ from typing import Any, Generator, Hashable, Optional, Union
from pydantic import BaseModel
from sd_internal import Request, Response
THREAD_NAME_PREFIX = 'Runtime-Render/'
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
@ -285,12 +286,15 @@ def thread_render(device):
current_state = ServerStates.Online
def is_alive(name=None):
from . import runtime # When calling runtime from here DO NOT USE thread specific attributes or functions.
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
nbr_alive = 0
try:
for rthread in render_threads:
if name and not rthread.name.endswith(name):
continue
thread_name = rthread.name[len(THREAD_NAME_PREFIX):]
if name and thread_name != name:
if not runtime.is_first_cuda_device(name) and not runtime.is_first_cuda_device(thread_name):
continue
if rthread.is_alive():
nbr_alive += 1
return nbr_alive
@ -303,7 +307,7 @@ def start_render_thread(device='auto'):
try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
rthread.daemon = True
rthread.name = 'Runner/' + device
rthread.name = THREAD_NAME_PREFIX + device
rthread.start()
render_threads.append(rthread)
finally: