mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-04-11 21:18:28 +02:00
Fix the error where a device named 'None' would get assigned for incompatible GPUs
This commit is contained in:
parent
b4f7d6bf25
commit
c13bccc7ae
@ -69,31 +69,31 @@ def validate_device_id(device, allow_auto=False, log_prefix=''):
|
||||
if not isinstance(device, str) or (device not in device_names and (len(device) <= len('cuda:') or device[:5] != 'cuda:' or not device[5:].isnumeric())):
|
||||
raise EnvironmentError(f"{log_prefix}: device id should be {', '.join(device_names)}, or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
|
||||
|
||||
def device_would_fail(device):
|
||||
validate_device_id(device, allow_auto=False, log_prefix='device_would_fail')
|
||||
'''
|
||||
Returns True/False, and prints any compatibility errors
|
||||
'''
|
||||
def is_device_compatible(device):
|
||||
validate_device_id(device, allow_auto=False, log_prefix='is_device_compatible')
|
||||
|
||||
if device == 'cpu': return None
|
||||
# Returns None when no issues found, otherwise returns the detected error str.
|
||||
if device == 'cpu': return True
|
||||
# Memory check
|
||||
try:
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
if mem_total < 3.0:
|
||||
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
|
||||
print('GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion')
|
||||
return False
|
||||
except RuntimeError as e:
|
||||
return str(e) # Return cuda errors from mem_get_info as strings
|
||||
return None
|
||||
print(str(e))
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_select(device):
|
||||
validate_device_id(device, allow_auto=False, log_prefix='device_select')
|
||||
|
||||
if device == 'cpu': return True
|
||||
if not torch.cuda.is_available(): return False
|
||||
failure_msg = device_would_fail(device)
|
||||
if failure_msg:
|
||||
if 'invalid device' in failure_msg:
|
||||
raise NameError(f'{device} could not be found. Remove this device from config.render_devices or use "auto".')
|
||||
print(failure_msg)
|
||||
if not is_device_compatible(device):
|
||||
return False
|
||||
|
||||
thread_data.device_name = torch.cuda.get_device_name(device)
|
||||
|
@ -366,7 +366,7 @@ def get_devices():
|
||||
gpu_count = torch.cuda.device_count()
|
||||
for device in range(gpu_count):
|
||||
device = f'cuda:{device}'
|
||||
if runtime.device_would_fail(device):
|
||||
if not runtime.is_device_compatible(device):
|
||||
continue
|
||||
|
||||
devices['all'].update({device: torch.cuda.get_device_name(device)})
|
||||
|
Loading…
Reference in New Issue
Block a user