Fixed is_alive to work with render_threads that can update the device name after starting.

This commit is contained in:
Marc-Andre Ferland 2022-10-18 13:21:15 -04:00
parent 940236b4a4
commit 5e461e9b6b
2 changed files with 37 additions and 17 deletions

View File

@ -3,7 +3,7 @@ import traceback
TASK_TTL = 15 * 60 # Discard last session's task timeout TASK_TTL = 15 * 60 # Discard last session's task timeout
import queue, threading, time import queue, threading, time, weakref
from typing import Any, Generator, Hashable, Optional, Union from typing import Any, Generator, Hashable, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -165,6 +165,7 @@ current_model_path = None
tasks_queue = [] tasks_queue = []
task_cache = TaskCache() task_cache = TaskCache()
default_model_to_load = None default_model_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
def preload_model(file_path=None): def preload_model(file_path=None):
global current_state, current_state_error, current_model_path global current_state, current_state_error, current_model_path
@ -189,11 +190,17 @@ def preload_model(file_path=None):
def thread_render(device): def thread_render(device):
global current_state, current_state_error, current_model_path global current_state, current_state_error, current_model_path
from . import runtime from . import runtime
weak_thread_data[threading.current_thread()] = {
'device': device
}
try: try:
runtime.device_init(device) runtime.device_init(device)
except: except:
print(traceback.format_exc()) print(traceback.format_exc())
return return
weak_thread_data[threading.current_thread()] = {
'device': runtime.thread_data.device
}
preload_model() preload_model()
current_state = ServerStates.Online current_state = ServerStates.Online
while True: while True:
@ -308,8 +315,11 @@ def is_alive(name=None):
nbr_alive = 0 nbr_alive = 0
try: try:
for rthread in render_threads: for rthread in render_threads:
thread_name = rthread.name[len(THREAD_NAME_PREFIX):].lower()
if name is not None: if name is not None:
weak_data = weak_thread_data.get(rthread)
if weak_data is None or weak_data['device'] is None:
continue
thread_name = str(weak_data['device']).lower()
if is_first_cuda_device(name): if is_first_cuda_device(name):
if not is_first_cuda_device(thread_name): if not is_first_cuda_device(thread_name):
continue continue

View File

@ -26,6 +26,7 @@ APP_CONFIG_DEFAULT_MODELS = [
'sd-v1-4', # Default fallback. 'sd-v1-4', # Default fallback.
] ]
import asyncio
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
@ -36,6 +37,7 @@ from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager from sd_internal import Request, Response, task_manager
LOOP = asyncio.get_event_loop()
app = FastAPI() app = FastAPI()
modifiers_cache = None modifiers_cache = None
@ -348,21 +350,29 @@ if 'render_devices' in config: # Start a new thread for each device.
for device in config['render_devices']: for device in config['render_devices']:
task_manager.start_render_thread(device) task_manager.start_render_thread(device)
allow_cpu = False async def check_status():
if task_manager.is_alive() <= 0: # No running devices, apply defaults. device_count = 0
# Select best device GPU device using free memory if more than one device. for i in range(10): # Wait for devices to register and/or change names.
task_manager.start_render_thread('auto') new_count = task_manager.is_alive()
allow_cpu = True if device_count != new_count:
device_count = new_count
# Allow CPU to be used for renders if not already enabled in current config. await asyncio.sleep(3)
if task_manager.is_alive('cpu') <= 0 and allow_cpu: else:
task_manager.start_render_thread('cpu') break;
allow_cpu = False
if task_manager.is_alive(0) <= 0: # Missing cuda:0, warn the user. if task_manager.is_alive() <= 0: # No running devices, apply defaults.
print('WARNING: GFPGANer only works on CPU or GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # Select best device GPU device using free memory if more than one device.
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer') task_manager.start_render_thread('auto')
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat') allow_cpu = True
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh') # Allow CPU to be used for renders if not already enabled in current config.
if task_manager.is_alive('cpu') <= 0 and allow_cpu:
task_manager.start_render_thread('cpu')
if not task_manager.is_alive(0) <= 0:
print('WARNING: GFPGANer only works on CPU or GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
LOOP.create_task(check_status())
# start the browser ui # start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000') import webbrowser; webbrowser.open('http://localhost:9000')