Handle device init failures, and record them as an error, if the GPU has less than 3 GB of VRAM

This commit is contained in:
cmdr2 2022-11-11 16:13:27 +05:30
parent c13bccc7ae
commit 3fdd8d91e2
2 changed files with 13 additions and 19 deletions

View File

@@ -144,22 +144,21 @@ def device_init(device_selection):
thread_data.device = 'cpu'
thread_data.device_name = get_processor_name()
print('Render device CPU available as', thread_data.device_name)
return
return True
if not torch.cuda.is_available():
if device_selection == 'auto':
print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
thread_data.device = 'cpu'
thread_data.device_name = get_processor_name()
return
return True
else:
raise EnvironmentError(f'Could not find a compatible GPU for the requested device_selection: {device_selection}!')
if device_selection == 'auto':
device_count = torch.cuda.device_count()
if device_count == 1:
device_select('cuda:0')
if device_count == 1 and device_select('cuda:0'):
torch.cuda.device('cuda:0')
return
return True
print('Autoselecting GPU. Using most free memory.')
max_mem_free = 0
@@ -177,23 +176,14 @@ def device_init(device_selection):
if best_device and device_select(best_device):
print(f'Setting {device} as active')
torch.cuda.device(device)
return
return True
if device_select(device_selection):
if device_selection != 'auto' and device_select(device_selection):
print(f'Setting {device_selection} as active')
torch.cuda.device(device_selection)
return
return True
# By default use current device.
print('Checking current GPU...')
device = f'cuda:{torch.cuda.current_device()}'
device_name = torch.cuda.get_device_name(device)
print(f'{device} detected: {device_name}')
if device_select(device):
return
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
thread_data.device = 'cpu'
thread_data.device_name = get_processor_name()
return False
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')

View File

@@ -253,7 +253,11 @@ def thread_render(device):
global current_state, current_state_error, current_model_path, current_vae_path
from . import runtime
try:
runtime.device_init(device)
if not runtime.device_init(device):
weak_thread_data[threading.current_thread()] = {
'error': f'Could not start on the selected device: {device}'
}
return
except Exception as e:
print(traceback.format_exc())
weak_thread_data[threading.current_thread()] = {