Fix the error where a device named 'None' would get assigned for incompatible GPUs

cmdr2 2022-11-11 15:43:20 +05:30
parent b4f7d6bf25
commit c13bccc7ae
2 changed files with 13 additions and 13 deletions
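
The rename from device_would_fail() to is_device_compatible() also flips the return contract: the old helper returned None when a GPU was usable and an error string when it was not, while the new one returns a plain True/False and prints the compatibility error itself. Below is a minimal illustration of the two contracts; these are simplified stand-ins, not code from this commit, and the VRAM total is passed in directly instead of querying torch.cuda.mem_get_info():

# Illustration only: simplified stand-ins for the old and new contracts.
MIN_VRAM_GB = 3.0

def device_would_fail_old(mem_total_gb):
    # Old contract: None means "compatible", a string is the failure reason.
    if mem_total_gb < MIN_VRAM_GB:
        return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
    return None

def is_device_compatible_new(mem_total_gb):
    # New contract: plain boolean; the failure reason is printed, not returned.
    if mem_total_gb < MIN_VRAM_GB:
        print('GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion')
        return False
    return True

# Callers now read naturally:
#   old: if device_would_fail_old(2.0): skip the device
#   new: if not is_device_compatible_new(2.0): skip the device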


@@ -69,31 +69,31 @@ def validate_device_id(device, allow_auto=False, log_prefix=''):
     if not isinstance(device, str) or (device not in device_names and (len(device) <= len('cuda:') or device[:5] != 'cuda:' or not device[5:].isnumeric())):
         raise EnvironmentError(f"{log_prefix}: device id should be {', '.join(device_names)}, or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")

-def device_would_fail(device):
-    validate_device_id(device, allow_auto=False, log_prefix='device_would_fail')
+def is_device_compatible(device):
+    '''
+    Returns True/False, and prints any compatibility errors
+    '''
+    validate_device_id(device, allow_auto=False, log_prefix='is_device_compatible')

-    # Returns None when no issues found, otherwise returns the detected error str.
-    if device == 'cpu': return None
+    if device == 'cpu': return True

     # Memory check
     try:
         mem_free, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)
         if mem_total < 3.0:
-            return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
+            print('GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion')
+            return False
     except RuntimeError as e:
-        return str(e) # Return cuda errors from mem_get_info as strings
+        print(str(e))
+        return False

-    return None
+    return True

 def device_select(device):
     validate_device_id(device, allow_auto=False, log_prefix='device_select')

     if device == 'cpu': return True
     if not torch.cuda.is_available(): return False

-    failure_msg = device_would_fail(device)
-    if failure_msg:
-        if 'invalid device' in failure_msg:
-            raise NameError(f'{device} could not be found. Remove this device from config.render_devices or use "auto".')
-
-        print(failure_msg)
+    if not is_device_compatible(device):
         return False

     thread_data.device_name = torch.cuda.get_device_name(device)


@@ -366,7 +366,7 @@ def get_devices():
     gpu_count = torch.cuda.device_count()
     for device in range(gpu_count):
         device = f'cuda:{device}'
-        if runtime.device_would_fail(device):
+        if not runtime.is_device_compatible(device):
             continue

         devices['all'].update({device: torch.cuda.get_device_name(device)})
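
Taken together, the two hunks leave callers filtering devices on a plain boolean. A rough caller-side sketch follows; the "from sd_internal import runtime" path and the list_compatible_gpus() name are assumptions for illustration, not part of this commit:

import torch
from sd_internal import runtime  # assumed import path; adjust to the project layout

def list_compatible_gpus():
    # Enumerate CUDA devices and keep only the ones the renamed helper accepts.
    gpus = {}
    if not torch.cuda.is_available():
        return gpus  # CPU-only environment: nothing to enumerate
    for i in range(torch.cuda.device_count()):
        device = f'cuda:{i}'
        if not runtime.is_device_compatible(device):
            continue  # the helper already printed the reason (low VRAM or a CUDA error)
        gpus[device] = torch.cuda.get_device_name(device)
    return gpus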