Merge pull request #407 from madrang/beta

Wait until device is fully ready before proceeding.
This commit is contained in:
cmdr2 2022-10-28 10:28:16 +05:30 committed by GitHub
commit d5a012d49f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 45 deletions

View File

@ -20,6 +20,8 @@ ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
class SymbolClass(type): # Print nicely formatted Symbol names. class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__ def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__ def __str__(self): return self.__name__
@ -240,9 +242,6 @@ def thread_get_next_task():
def thread_render(device): def thread_render(device):
global current_state, current_state_error, current_model_path global current_state, current_state_error, current_model_path
from . import runtime from . import runtime
weak_thread_data[threading.current_thread()] = {
'device': device
}
try: try:
runtime.device_init(device) runtime.device_init(device)
except: except:
@ -251,7 +250,7 @@ def thread_render(device):
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
'device': runtime.thread_data.device 'device': runtime.thread_data.device
} }
if runtime.thread_data.device != 'cpu': if runtime.thread_data.device != 'cpu' or is_alive() == 1:
preload_model() preload_model()
current_state = ServerStates.Online current_state = ServerStates.Online
while True: while True:
@ -367,8 +366,8 @@ def start_render_thread(device='auto'):
rthread.daemon = True rthread.daemon = True
rthread.name = THREAD_NAME_PREFIX + device rthread.name = THREAD_NAME_PREFIX + device
rthread.start() rthread.start()
timeout = LOCK_TIMEOUT timeout = DEVICE_START_TIMEOUT
while not rthread.is_alive(): while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
if timeout <= 0: raise Exception('render_thread', rthread.name, 'failed to start before timeout or has crashed.') if timeout <= 0: raise Exception('render_thread', rthread.name, 'failed to start before timeout or has crashed.')
timeout -= 1 timeout -= 1
time.sleep(1) time.sleep(1)

View File

@ -31,7 +31,6 @@ APP_CONFIG_DEFAULT_MODELS = [
'sd-v1-4', # Default fallback. 'sd-v1-4', # Default fallback.
] ]
import asyncio
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
@ -42,7 +41,6 @@ from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager from sd_internal import Request, Response, task_manager
LOOP = asyncio.get_event_loop()
app = FastAPI() app = FastAPI()
modifiers_cache = None modifiers_cache = None
@ -362,54 +360,37 @@ logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
config = getConfig() config = getConfig()
async def check_status(): # Task to Validate user config shortly after startup.
# Check that the loaded config.json yielded a server in a known valid state.
# When issues are found, try to fix them when possible and warn the user.
device_count = 0
# Wait for devices to register and/or change names.
THREAD_START_DELAY = 5 # seconds - Give time for devices/threads to start.
for i in range(10): # Maximum number of retries.
await asyncio.sleep(THREAD_START_DELAY)
new_count = task_manager.is_alive()
# Stops retry once no more devices show up.
if new_count > 0 and device_count == new_count: break
device_count = new_count
if 'render_devices' in config and task_manager.is_alive() <= 0: # No running devices, probably invalid user config. Try to apply defaults.
print('WARNING: No active render devices after loading config. Validate "render_devices" in config.json')
task_manager.start_render_thread('auto') # Detect best device for renders
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
print('Default render devices loaded to replace missing render_devices', config['render_devices'])
display_warning = False
if not 'render_devices' in config and task_manager.is_alive(0) <= 0: # No config set, is on auto mode and without cuda:0
task_manager.start_render_thread('cuda') # Another cuda device is better and cuda:0 is missing, start it...
display_warning = True # And warn user to update settings...
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
if display_warning or task_manager.is_alive(0) <= 0:
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
# Start the task_manager # Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use() task_manager.default_model_to_load = resolve_model_to_use()
if 'render_devices' in config: # Start a new thread for each device. if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str): if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',') config['render_devices'] = config['render_devices'].split(',')
if not isinstance(config['render_devices'], list): if not isinstance(config['render_devices'], list):
raise Exception('Invalid render_devices value in config.') raise Exception('Invalid render_devices value in config.')
for device in config['render_devices']: for device in config['render_devices']:
task_manager.start_render_thread(device) task_manager.start_render_thread(device)
else: if task_manager.is_alive() <= 0: # No running devices, probably invalid user config.
print('WARNING: No active render devices after loading config. Validate "render_devices" in config.json')
print('Loading default render devices to replace invalid render_devices field from config', config['render_devices'])
display_warning = False
if task_manager.is_alive() <= 0: # Either no defaults or no devices after loading config.
# Select best GPU device using free memory, if more than one device. # Select best GPU device using free memory, if more than one device.
task_manager.start_render_thread('auto') # Detect best device for renders task_manager.start_render_thread('auto') # Detect best device for renders
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders if task_manager.is_alive(0) <= 0: # without cuda:0
task_manager.start_render_thread('cuda') # Another cuda device is better and cuda:0 is missing, start it...
display_warning = True # And warn user to update settings...
if task_manager.is_alive('cpu') <= 0:
# Allow CPU to be used for renders
task_manager.start_render_thread('cpu')
# Task to Validate user config shortly after startup. if display_warning or task_manager.is_alive(0) <= 0:
LOOP.create_task(check_status()) print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
del display_warning
# start the browser ui # start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000') import webbrowser; webbrowser.open('http://localhost:9000')