Fix the error where a device named 'None' would get assigned to incompatible GPUs

This commit is contained in:
cmdr2 2022-11-11 15:43:20 +05:30
parent b4f7d6bf25
commit c13bccc7ae
2 changed files with 13 additions and 13 deletions

View File

@@ -69,31 +69,31 @@ def validate_device_id(device, allow_auto=False, log_prefix=''):
if not isinstance(device, str) or (device not in device_names and (len(device) <= len('cuda:') or device[:5] != 'cuda:' or not device[5:].isnumeric())): if not isinstance(device, str) or (device not in device_names and (len(device) <= len('cuda:') or device[:5] != 'cuda:' or not device[5:].isnumeric())):
raise EnvironmentError(f"{log_prefix}: device id should be {', '.join(device_names)}, or 'cuda:N' (where N is an integer index for the GPU). Got: {device}") raise EnvironmentError(f"{log_prefix}: device id should be {', '.join(device_names)}, or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
def device_would_fail(device): '''
validate_device_id(device, allow_auto=False, log_prefix='device_would_fail') Returns True/False, and prints any compatibility errors
'''
def is_device_compatible(device):
validate_device_id(device, allow_auto=False, log_prefix='is_device_compatible')
if device == 'cpu': return None if device == 'cpu': return True
# Returns None when no issues found, otherwise returns the detected error str.
# Memory check # Memory check
try: try:
mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9) mem_total /= float(10**9)
if mem_total < 3.0: if mem_total < 3.0:
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion' print('GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion')
return False
except RuntimeError as e: except RuntimeError as e:
return str(e) # Return cuda errors from mem_get_info as strings print(str(e))
return None return False
return True
def device_select(device): def device_select(device):
validate_device_id(device, allow_auto=False, log_prefix='device_select') validate_device_id(device, allow_auto=False, log_prefix='device_select')
if device == 'cpu': return True if device == 'cpu': return True
if not torch.cuda.is_available(): return False if not torch.cuda.is_available(): return False
failure_msg = device_would_fail(device) if not is_device_compatible(device):
if failure_msg:
if 'invalid device' in failure_msg:
raise NameError(f'{device} could not be found. Remove this device from config.render_devices or use "auto".')
print(failure_msg)
return False return False
thread_data.device_name = torch.cuda.get_device_name(device) thread_data.device_name = torch.cuda.get_device_name(device)

View File

@@ -366,7 +366,7 @@ def get_devices():
gpu_count = torch.cuda.device_count() gpu_count = torch.cuda.device_count()
for device in range(gpu_count): for device in range(gpu_count):
device = f'cuda:{device}' device = f'cuda:{device}'
if runtime.device_would_fail(device): if not runtime.is_device_compatible(device):
continue continue
devices['all'].update({device: torch.cuda.get_device_name(device)}) devices['all'].update({device: torch.cuda.get_device_name(device)})