Merge pull request #422 from madrang/device-select

Implement complete device selection in the backend.
cmdr2 2022-11-02 12:05:59 +05:30 committed by GitHub
commit 976bc727dd
4 changed files with 98 additions and 38 deletions
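
For reviewers: after this change, clients select a device by passing a render_device string instead of the old use_cpu boolean (which the server now rewrites to render_device='cpu' during a transition period). A minimal sketch of the new request shape, assuming the server runs at its usual localhost:9000 and omitting the many other ImageRequest fields:

    import requests

    payload = {
        'session_id': 'demo-session',             # placeholder value
        'prompt': 'an astronaut riding a horse',  # placeholder value
        'render_device': 'cuda:1',                # or 'cpu'; omit to let the scheduler pick
        # 'use_cpu': True,  # deprecated: rewritten server-side to render_device='cpu'
    }
    res = requests.post('http://localhost:9000/render', json=payload)
    print(res.json())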

View File

@@ -18,7 +18,6 @@ class Request:
     precision: str = "autocast" # or "full"
     save_to_disk_path: str = None
     turbo: bool = True
-    use_cpu: bool = False
     use_full_precision: bool = False
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
@@ -50,7 +49,7 @@ class Request:
             "output_format": self.output_format,
         }

-    def to_string(self):
+    def __str__(self):
         return f'''
     session_id: {self.session_id}
     prompt: {self.prompt}
@@ -64,7 +63,6 @@ class Request:
     precision: {self.precision}
     save_to_disk_path: {self.save_to_disk_path}
     turbo: {self.turbo}
-    use_cpu: {self.use_cpu}
     use_full_precision: {self.use_full_precision}
     use_face_correction: {self.use_face_correction}
     use_upscale: {self.use_upscale}
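
Side note on the to_string → __str__ rename above: print() and f-strings invoke __str__ implicitly, which is what lets the call site in the runtime (next file) shorten print(req.to_string(), ...) to print(req, ...). A tiny stand-in illustration (this class is a simplified mock, not the real Request):

    class Request:
        session_id: str = 'demo'
        def __str__(self):
            return f'session_id: {self.session_id}'

    req = Request()
    print(req)        # implicitly calls req.__str__()
    print(f'{req}')   # same inside an f-string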

View File

@@ -45,6 +45,25 @@ from io import BytesIO
 from threading import local as LocalThreadVars
 thread_data = LocalThreadVars()

+def get_processor_name():
+    try:
+        import platform, subprocess
+        if platform.system() == "Windows":
+            return platform.processor()
+        elif platform.system() == "Darwin":
+            os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
+            command = "sysctl -n machdep.cpu.brand_string"
+            return subprocess.check_output(command, shell=True).decode().strip()
+        elif platform.system() == "Linux":
+            command = "cat /proc/cpuinfo"
+            all_info = subprocess.check_output(command, shell=True).decode().strip()
+            for line in all_info.split("\n"):
+                if "model name" in line:
+                    return re.sub(".*model name.*:", "", line, 1).strip()
+    except Exception:
+        print(traceback.format_exc())
+    return "cpu"
+
 def device_would_fail(device):
     if device == 'cpu': return None
     # Returns None when no issues found, otherwise returns the detected error str.
@@ -68,17 +87,17 @@ def device_select(device):
         print(failure_msg)
         return False

-    device_name = torch.cuda.get_device_name(device)
+    thread_data.device_name = torch.cuda.get_device_name(device)
+    thread_data.device = device

-    # otherwise these NVIDIA cards create green images
-    thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
+    # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
+    device_name = thread_data.device_name.lower()
+    thread_data.force_full_precision = ('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)
     if thread_data.force_full_precision:
-        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', device_name)
+        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name)
         # Apply force_full_precision now before models are loaded.
         thread_data.precision = 'full'

-    thread_data.device = device
-    thread_data.has_valid_gpu = True
     return True

 def device_init(device_selection=None):
@@ -100,24 +119,26 @@ def device_init(device_selection=None):
     thread_data.model_is_half = False
     thread_data.model_fs_is_half = False
     thread_data.device = None
+    thread_data.device_name = None
     thread_data.unet_bs = 1
     thread_data.precision = 'autocast'
     thread_data.sampler_plms = None
     thread_data.sampler_ddim = None

     thread_data.turbo = False
-    thread_data.has_valid_gpu = False
     thread_data.force_full_precision = False
     thread_data.reduced_memory = True

     if device_selection.lower() == 'cpu':
-        print('CPU requested, skipping gpu init.')
         thread_data.device = 'cpu'
+        thread_data.device_name = get_processor_name()
+        print('Render device CPU available as', thread_data.device_name)
         return
     if not torch.cuda.is_available():
         if device_selection == 'auto' or device_selection == 'current':
             print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
             thread_data.device = 'cpu'
+            thread_data.device_name = get_processor_name()
             return
         else:
             raise EnvironmentError('torch.cuda is not available.')
@@ -162,6 +183,7 @@ def device_init(device_selection=None):
         return
     print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
     thread_data.device = 'cpu'
+    thread_data.device_name = get_processor_name()

 def is_first_cuda_device(device):
     if device is None: return False
@@ -475,7 +497,7 @@ def do_mk_img(req: Request):
         thread_data.vae_file = req.use_vae_model
         needs_model_reload = True

-    if thread_data.has_valid_gpu:
+    if thread_data.device != 'cpu':
         if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
             (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
             thread_data.precision = 'full' if req.use_full_precision else 'autocast'
@@ -500,7 +522,7 @@ def do_mk_img(req: Request):
     opt_f = 8
     opt_ddim_eta = 0.0

-    print(req.to_string(), '\n    device', thread_data.device)
+    print(req, '\n    device', torch.device(thread_data.device), "as", thread_data.device_name)
     print('\n\n    Using precision:', thread_data.precision)

     seed_everything(opt_seed)
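
The runtime changes above hang all per-device state (device, device_name, precision, ...) off a threading.local, so every render thread sees only the values it set during its own device_init. A self-contained sketch of that pattern, with simplified bodies that only mirror the structure of the diff:

    import threading

    thread_data = threading.local()  # each thread gets an independent attribute namespace

    def device_init(device_selection):
        # Attributes set here are visible only to the thread that called device_init.
        thread_data.device = device_selection
        thread_data.device_name = 'CPU' if device_selection == 'cpu' else 'GPU ' + device_selection

    def worker(device):
        device_init(device)
        print(threading.current_thread().name, '->', thread_data.device, 'as', thread_data.device_name)

    for dev in ('cpu', 'cuda:0', 'cuda:1'):
        threading.Thread(target=worker, args=(dev,)).start()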

View File

@@ -21,6 +21,7 @@ LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
 # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.

 DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
+CPU_UNLOAD_TIMEOUT = 4 * 60 # seconds - Idle time before the CPU unloads resources when GPUs are present.

 class SymbolClass(type): # Print nicely formatted Symbol names.
     def __repr__(self): return self.__qualname__
@@ -38,6 +39,7 @@ class RenderTask(): # Task with output queue and completion lock.
     def __init__(self, req: Request):
         self.request: Request = req # Initial Request
         self.response: Any = None # Copy of the last response
+        self.render_device = None
         self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
         self.error: Exception = None
         self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
@@ -68,7 +70,8 @@ class ImageRequest(BaseModel):
     # allow_nsfw: bool = False
     save_to_disk_path: str = None
     turbo: bool = True
-    use_cpu: bool = False
+    use_cpu: bool = False ##TODO Remove after UI and plugins transition.
+    render_device: str = None
     use_full_precision: bool = False
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
@@ -89,7 +92,7 @@ class FilterRequest(BaseModel):
     height: int = 512
     save_to_disk_path: str = None
     turbo: bool = True
-    use_cpu: bool = False
+    render_device: str = None
     use_full_precision: bool = False
     output_format: str = "jpeg" # or "png"
@@ -219,26 +222,24 @@ def thread_get_next_task():
                 queued_task.error = Exception('cuda:0 is not available with the current config. Remove GFPGANer filter to run task.')
                 task = queued_task
                 break
-            if queued_task.request.use_cpu:
+            if queued_task.render_device == 'cpu':
                 queued_task.error = Exception('Cpu cannot be used to run this task. Remove GFPGANer filter to run task.')
                 task = queued_task
                 break
             if not runtime.is_first_cuda_device(runtime.thread_data.device):
                 continue # Wait for cuda:0
-        if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
-            if is_alive('cpu') > 0:
-                continue # CPU Tasks, Skip GPU device
+        if queued_task.render_device and runtime.thread_data.device != queued_task.render_device:
+            # Task is asking for a specific render device.
+            if is_alive(queued_task.render_device) > 0:
+                continue # Requested device is alive; skip the current one.
             else:
-                queued_task.error = Exception('Cpu is not enabled in render_devices.')
+                # Requested device is not active; return an error to the UI.
+                queued_task.error = Exception(str(queued_task.render_device) + ' is not currently active.')
                 task = queued_task
                 break
-        if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
-            if is_alive() > 1: # cpu is alive, so need more than one.
-                continue # GPU Tasks, don't run on CPU unless there is nothing else.
-            else:
-                queued_task.error = Exception('No active gpu found. Please check the error message in the command-line window at startup.')
-                task = queued_task
-                break
+        if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1:
+            # Not asking for any specific device; the CPU wants to grab the task but other render devices are alive.
+            continue # Skip the task; don't run on CPU unless there is nothing else or the user asked for it.
         task = queued_task
         break
     if task is not None:
@@ -252,11 +253,15 @@ def thread_render(device):
     from . import runtime
     try:
         runtime.device_init(device)
-    except:
+    except Exception as e:
         print(traceback.format_exc())
+        weak_thread_data[threading.current_thread()] = {
+            'error': e
+        }
         return
     weak_thread_data[threading.current_thread()] = {
-        'device': runtime.thread_data.device
+        'device': runtime.thread_data.device,
+        'device_name': runtime.thread_data.device_name
     }
     if runtime.thread_data.device != 'cpu' or is_alive() == 1:
         preload_model()
@@ -268,6 +273,11 @@ def thread_render(device):
             return
         task = thread_get_next_task()
         if task is None:
+            if runtime.thread_data.device == 'cpu' and is_alive() > 1 and hasattr(runtime.thread_data, 'lastActive') and time.time() - runtime.thread_data.lastActive > CPU_UNLOAD_TIMEOUT:
+                # GPUs are present and the CPU is idle. Unload resources.
+                runtime.unload_models()
+                runtime.unload_filters()
+                del runtime.thread_data.lastActive
             time.sleep(1)
             continue
         if task.error is not None:
@@ -280,9 +290,12 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        print(f'Session {task.request.session_id} starting task {id(task)}')
+        print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
+            if runtime.thread_data.device == 'cpu' and is_alive() > 1:
+                # CPU is not the only device. Keep track of active time to unload resources later.
+                runtime.thread_data.lastActive = time.time()
             # Open data generator.
             res = runtime.mk_img(task.request)
             if current_model_path == task.request.use_stable_diffusion_model:
@@ -331,7 +344,7 @@ def thread_render(device):
         elif task.error is not None:
             print(f'Session {task.request.session_id} task {id(task)} failed!')
         else:
-            print(f'Session {task.request.session_id} task {id(task)} completed.')
+            print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
         current_state = ServerStates.Online

 def get_cached_task(session_id:str, update_ttl:bool=False):
@@ -341,6 +354,21 @@ def get_cached_task(session_id:str, update_ttl:bool=False):
         return None
     return task_cache.tryGet(session_id)

+def get_devices():
+    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED)
+    try:
+        device_dict = {}
+        for rthread in render_threads:
+            if not rthread.is_alive():
+                continue
+            weak_data = weak_thread_data.get(rthread)
+            if not weak_data or 'device' not in weak_data or 'device_name' not in weak_data:
+                continue
+            device_dict.update({weak_data['device']: weak_data['device_name']})
+        return device_dict
+    finally:
+        manager_lock.release()
+
 def is_first_cuda_device(device):
     from . import runtime # When calling runtime from outside thread_render DO NOT USE thread specific attributes or functions.
     return runtime.is_first_cuda_device(device)
@@ -352,8 +380,7 @@ def is_alive(name=None):
     for rthread in render_threads:
         if name is not None:
             weak_data = weak_thread_data.get(rthread)
-            if weak_data is None or weak_data['device'] is None:
-                print('The thread', rthread.name, 'is registered but has no data store in the task manager.')
+            if weak_data is None or 'device' not in weak_data or weak_data['device'] is None:
                 continue
             thread_name = str(weak_data['device']).lower()
             if is_first_cuda_device(name):
@@ -380,6 +407,8 @@ def start_render_thread(device='auto'):
     manager_lock.release()
     timeout = DEVICE_START_TIMEOUT
     while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
+        if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
+            return False
         if timeout <= 0:
             return False
         timeout -= 1
@@ -416,7 +445,6 @@ def render(req : ImageRequest):
     r.sampler = req.sampler
     # r.allow_nsfw = req.allow_nsfw
     r.turbo = req.turbo
-    r.use_cpu = req.use_cpu
     r.use_full_precision = req.use_full_precision
     r.save_to_disk_path = req.save_to_disk_path
     r.use_upscale: str = req.use_upscale
@@ -433,6 +461,8 @@ def render(req : ImageRequest):
     r.stream_image_progress = False

     new_task = RenderTask(r)
+    new_task.render_device = req.render_device
+
     if task_cache.put(r.session_id, new_task, TASK_TTL):
         # Use twice the normal timeout for adding user requests.
         # Tries to force task_cache.put to fail before tasks_queue.put would.
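
The scheduling policy introduced in thread_get_next_task above boils down to three rules: error out if the task's requested device is dead, skip the task if another live thread owns the requested device, and let the CPU pick up unpinned work only when it is the last device standing. A condensed sketch of that decision as a hypothetical helper (not code from the diff):

    def should_take(task_device, my_device, alive):
        """alive maps device -> thread-is-alive, e.g. {'cuda:0': True, 'cpu': True}."""
        if task_device and my_device != task_device:
            if alive.get(task_device):
                return 'skip'   # requested device is alive; let its thread take the task
            return 'error'      # requested device is not active; surface an error to the UI
        if not task_device and my_device == 'cpu' and sum(alive.values()) > 1:
            return 'skip'       # unpinned task: leave it for a GPU unless CPU is the only device
        return 'take'

    assert should_take('cuda:1', 'cuda:0', {'cuda:0': True, 'cuda:1': True}) == 'skip'
    assert should_take('cuda:1', 'cuda:0', {'cuda:0': True}) == 'error'
    assert should_take(None, 'cpu', {'cuda:0': True, 'cpu': True}) == 'skip'
    assert should_take(None, 'cpu', {'cpu': True}) == 'take'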

View File

@@ -271,6 +271,8 @@ def read_web_data(key:str=None):
         if config is None:
             raise HTTPException(status_code=500, detail="Config file is missing or unreadable")
         return JSONResponse(config, headers=NOCACHE_HEADERS)
+    elif key == 'devices':
+        return JSONResponse(task_manager.get_devices(), headers=NOCACHE_HEADERS)
     elif key == 'models':
         return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
     elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
@@ -315,7 +317,11 @@ def save_model_to_config(model_name):
 @app.post('/render')
 def render(req : task_manager.ImageRequest):
-    if req.use_cpu and task_manager.is_alive('cpu') <= 0: raise HTTPException(status_code=403, detail=f'CPU rendering is not enabled in config.json or the thread has died...') # HTTP403 Forbidden
+    if req.use_cpu: # TODO Remove after transition.
+        print('WARNING: Replace {"use_cpu": true} with {"render_device": "cpu"} in the request.')
+        req.render_device = 'cpu'
+        del req.use_cpu
+    if req.render_device and task_manager.is_alive(req.render_device) <= 0: raise HTTPException(status_code=403, detail=f'{req.render_device} rendering is not enabled in config.json or the thread has died...') # HTTP403 Forbidden
     if req.use_face_correction and task_manager.is_alive(0) <= 0: #TODO Remove when GFPGANer is fixed upstream.
         raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
     try:
@@ -401,19 +407,24 @@ config = getConfig()

 # Start the task_manager
 task_manager.default_model_to_load = resolve_ckpt_to_use()
 task_manager.default_vae_to_load = resolve_vae_to_use(ckpt_model_path=task_manager.default_model_to_load)
+display_warning = False
 if 'render_devices' in config: # Start a new thread for each device.
     if isinstance(config['render_devices'], str):
         config['render_devices'] = config['render_devices'].split(',')
     if not isinstance(config['render_devices'], list):
         raise Exception('Invalid render_devices value in config.')
     for device in config['render_devices']:
+        if task_manager.is_alive(device) >= 1:
+            print(device, 'already registered.')
+            continue
         if not task_manager.start_render_thread(device):
             print(device, 'failed to start.')
     if task_manager.is_alive() <= 0: # No running devices, probably invalid user config.
         print('WARNING: No active render devices after loading config. Validate "render_devices" in config.json')
         print('Loading default render devices to replace invalid render_devices field from config', config['render_devices'])
+    elif task_manager.is_alive(0) <= 0: # Missing GPU:0
+        display_warning = True # Warn user to update settings...

-display_warning = False
 if task_manager.is_alive() <= 0: # Either no defaults or no devices after loading config.
     # Select best GPU device using free memory, if more than one device.
     if task_manager.start_render_thread('auto'): # Detect best device for renders
@@ -431,12 +442,11 @@ if task_manager.is_alive() <= 0: # Either no defaults or no devices after loading
     if not task_manager.start_render_thread('cpu'):
         print('Failed to start CPU render device...')

-if display_warning or task_manager.is_alive(0) <= 0:
+if display_warning:
     print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
     print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
     print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
     print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
 del display_warning

 # start the browser ui
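
With the 'devices' key wired into read_web_data above, a client can discover which render devices are active and how each resolved its human-readable name. A minimal query sketch, assuming the usual /get/{key} routing in front of read_web_data and the default localhost:9000 address:

    import requests

    res = requests.get('http://localhost:9000/get/devices')
    # Expected shape: {device: device_name}, for example
    # {'cuda:0': 'NVIDIA GeForce RTX 3060', 'cpu': 'Intel(R) Core(TM) i7-9700K CPU @ 3.60GHz'}
    for device, name in res.json().items():
        print(device, '->', name)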