Mirror of https://github.com/easydiffusion/easydiffusion.git
Fixed is_alive with render_threads that can update the device name after starting.
This commit is contained in:
parent 940236b4a4
commit 5e461e9b6b
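
The fix relies on the standard-library pattern of keying per-thread records in a weakref.WeakKeyDictionary: the manager holds no hard reference to a render thread's data, and a thread's record disappears once the Thread object itself is gone. A minimal sketch of that pattern, independent of the repository's code (worker and the 'device' record are illustrative only):

    import gc, threading, weakref

    weak_thread_data = weakref.WeakKeyDictionary()   # per-thread records, weakly keyed

    def worker():
        # each thread registers its own record, keyed by its Thread object
        weak_thread_data[threading.current_thread()] = {'device': 'cpu'}

    t = threading.Thread(target=worker)
    t.start(); t.join()
    print(len(weak_thread_data))   # 1: the record survives while 't' still references the thread

    del t                          # drop the last hard reference to the Thread object
    gc.collect()
    print(len(weak_thread_data))   # 0: the record is discarded together with the thread
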
ui/sd_internal/task_manager.py

@@ -3,7 +3,7 @@ import traceback
 
 TASK_TTL = 15 * 60 # Discard last session's task timeout
 
-import queue, threading, time
+import queue, threading, time, weakref
 from typing import Any, Generator, Hashable, Optional, Union
 
 from pydantic import BaseModel
@@ -165,6 +165,7 @@ current_model_path = None
 tasks_queue = []
 task_cache = TaskCache()
 default_model_to_load = None
+weak_thread_data = weakref.WeakKeyDictionary()
 
 def preload_model(file_path=None):
     global current_state, current_state_error, current_model_path
@@ -189,11 +190,17 @@ def preload_model(file_path=None):
 def thread_render(device):
     global current_state, current_state_error, current_model_path
     from . import runtime
+    weak_thread_data[threading.current_thread()] = {
+        'device': device
+    }
     try:
         runtime.device_init(device)
     except:
         print(traceback.format_exc())
         return
+    weak_thread_data[threading.current_thread()] = {
+        'device': runtime.thread_data.device
+    }
     preload_model()
     current_state = ServerStates.Online
     while True:
@@ -308,8 +315,11 @@ def is_alive(name=None):
     nbr_alive = 0
     try:
         for rthread in render_threads:
-            thread_name = rthread.name[len(THREAD_NAME_PREFIX):].lower()
             if name is not None:
+                weak_data = weak_thread_data.get(rthread)
+                if weak_data is None or weak_data['device'] is None:
+                    continue
+                thread_name = str(weak_data['device']).lower()
                 if is_first_cuda_device(name):
                     if not is_first_cuda_device(thread_name):
                         continue
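
To see why the lookup moves from the thread's name to weak_thread_data: a thread started for 'auto' only learns its real device inside device_init(), so counting live threads per device has to read the record that the thread itself updates. The sketch below is a simplified stand-in, not the repository's implementation (the real is_alive() also special-cases the first CUDA device):

    import threading, time, weakref

    weak_thread_data = weakref.WeakKeyDictionary()
    render_threads = []

    def thread_render(device):
        me = threading.current_thread()
        weak_thread_data[me] = {'device': device}      # provisional name, e.g. 'auto'
        time.sleep(0.1)                                # stand-in for runtime.device_init(device)
        weak_thread_data[me] = {'device': 'cuda:0'}    # resolved device name
        time.sleep(1.0)                                # stand-in for the render loop

    def is_alive(name=None):
        nbr_alive = 0
        for rthread in render_threads:
            if name is not None:
                weak_data = weak_thread_data.get(rthread)
                if weak_data is None or weak_data.get('device') is None:
                    continue                           # thread not registered yet
                if str(weak_data['device']).lower() != str(name).lower():
                    continue                           # simplified: exact match only
            if rthread.is_alive():
                nbr_alive += 1
        return nbr_alive

    t = threading.Thread(target=thread_render, args=('auto',))
    render_threads.append(t)
    t.start()
    time.sleep(0.5)
    print(is_alive('cuda:0'))   # 1: found via the name the thread registered after init
    print(is_alive('auto'))     # 0: the provisional name is no longer what is stored
    t.join()
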
ui/server.py (16 changed lines)

@@ -26,6 +26,7 @@ APP_CONFIG_DEFAULT_MODELS = [
     'sd-v1-4', # Default fallback.
 ]
 
+import asyncio
 from fastapi import FastAPI, HTTPException
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import FileResponse, JSONResponse, StreamingResponse
@@ -36,6 +37,7 @@ from typing import Any, Generator, Hashable, List, Optional, Union
 
 from sd_internal import Request, Response, task_manager
 
+LOOP = asyncio.get_event_loop()
 app = FastAPI()
 
 modifiers_cache = None
@@ -348,21 +350,29 @@ if 'render_devices' in config: # Start a new thread for each device.
     for device in config['render_devices']:
         task_manager.start_render_thread(device)
 
+async def check_status():
+    device_count = 0
+    for i in range(10): # Wait for devices to register and/or change names.
+        new_count = task_manager.is_alive()
+        if device_count != new_count:
+            device_count = new_count
+            await asyncio.sleep(3)
+        else:
+            break;
 allow_cpu = False
 if task_manager.is_alive() <= 0: # No running devices, apply defaults.
     # Select best device GPU device using free memory if more than one device.
     task_manager.start_render_thread('auto')
     allow_cpu = True
 
 # Allow CPU to be used for renders if not already enabled in current config.
 if task_manager.is_alive('cpu') <= 0 and allow_cpu:
     task_manager.start_render_thread('cpu')
-if task_manager.is_alive(0) <= 0: # Missing cuda:0, warn the user.
+if not task_manager.is_alive(0) <= 0:
     print('WARNING: GFPGANer only works on CPU or GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
     print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
     print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
     print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
+LOOP.create_task(check_status())
 
 # start the browser ui
 import webbrowser; webbrowser.open('http://localhost:9000')
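
The server-side addition schedules a background coroutine that re-polls task_manager.is_alive() until the reported device count stops changing, since render threads may register late or rename their device after init. A runnable sketch of that polling pattern, with fake_is_alive() standing in for the real call:

    import asyncio, itertools

    # Simulated device counts: two devices come online, then the count stabilizes.
    counts = itertools.chain([1, 2, 2], itertools.repeat(2))
    def fake_is_alive():
        return next(counts)

    async def check_status():
        device_count = 0
        for _ in range(10):                  # wait for devices to register and/or change names
            new_count = fake_is_alive()
            if device_count != new_count:
                device_count = new_count
                await asyncio.sleep(0.1)     # give the render threads time to settle
            else:
                break
        print('devices online:', device_count)

    async def main():
        # server.py does the equivalent with LOOP.create_task(check_status())
        task = asyncio.get_event_loop().create_task(check_status())
        await task

    asyncio.run(main())
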