mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-30 04:04:08 +01:00
Handle device init failures and record that as an error, if the GPU has less than 3 gb of VRAM
This commit is contained in:
parent
c13bccc7ae
commit
3fdd8d91e2
@ -144,22 +144,21 @@ def device_init(device_selection):
|
|||||||
thread_data.device = 'cpu'
|
thread_data.device = 'cpu'
|
||||||
thread_data.device_name = get_processor_name()
|
thread_data.device_name = get_processor_name()
|
||||||
print('Render device CPU available as', thread_data.device_name)
|
print('Render device CPU available as', thread_data.device_name)
|
||||||
return
|
return True
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
if device_selection == 'auto':
|
if device_selection == 'auto':
|
||||||
print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
|
print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
|
||||||
thread_data.device = 'cpu'
|
thread_data.device = 'cpu'
|
||||||
thread_data.device_name = get_processor_name()
|
thread_data.device_name = get_processor_name()
|
||||||
return
|
return True
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError(f'Could not find a compatible GPU for the requested device_selection: {device_selection}!')
|
raise EnvironmentError(f'Could not find a compatible GPU for the requested device_selection: {device_selection}!')
|
||||||
|
|
||||||
if device_selection == 'auto':
|
if device_selection == 'auto':
|
||||||
device_count = torch.cuda.device_count()
|
device_count = torch.cuda.device_count()
|
||||||
if device_count == 1:
|
if device_count == 1 and device_select('cuda:0'):
|
||||||
device_select('cuda:0')
|
|
||||||
torch.cuda.device('cuda:0')
|
torch.cuda.device('cuda:0')
|
||||||
return
|
return True
|
||||||
|
|
||||||
print('Autoselecting GPU. Using most free memory.')
|
print('Autoselecting GPU. Using most free memory.')
|
||||||
max_mem_free = 0
|
max_mem_free = 0
|
||||||
@ -177,23 +176,14 @@ def device_init(device_selection):
|
|||||||
if best_device and device_select(best_device):
|
if best_device and device_select(best_device):
|
||||||
print(f'Setting {device} as active')
|
print(f'Setting {device} as active')
|
||||||
torch.cuda.device(device)
|
torch.cuda.device(device)
|
||||||
return
|
return True
|
||||||
|
|
||||||
if device_select(device_selection):
|
if device_selection != 'auto' and device_select(device_selection):
|
||||||
print(f'Setting {device_selection} as active')
|
print(f'Setting {device_selection} as active')
|
||||||
torch.cuda.device(device_selection)
|
torch.cuda.device(device_selection)
|
||||||
return
|
return True
|
||||||
|
|
||||||
# By default use current device.
|
return False
|
||||||
print('Checking current GPU...')
|
|
||||||
device = f'cuda:{torch.cuda.current_device()}'
|
|
||||||
device_name = torch.cuda.get_device_name(device)
|
|
||||||
print(f'{device} detected: {device_name}')
|
|
||||||
if device_select(device):
|
|
||||||
return
|
|
||||||
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
|
|
||||||
thread_data.device = 'cpu'
|
|
||||||
thread_data.device_name = get_processor_name()
|
|
||||||
|
|
||||||
def load_model_ckpt():
|
def load_model_ckpt():
|
||||||
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
||||||
|
@ -253,7 +253,11 @@ def thread_render(device):
|
|||||||
global current_state, current_state_error, current_model_path, current_vae_path
|
global current_state, current_state_error, current_model_path, current_vae_path
|
||||||
from . import runtime
|
from . import runtime
|
||||||
try:
|
try:
|
||||||
runtime.device_init(device)
|
if not runtime.device_init(device):
|
||||||
|
weak_thread_data[threading.current_thread()] = {
|
||||||
|
'error': f'Could not start on the selected device: {device}'
|
||||||
|
}
|
||||||
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
weak_thread_data[threading.current_thread()] = {
|
weak_thread_data[threading.current_thread()] = {
|
||||||
|
Loading…
Reference in New Issue
Block a user