diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 5f4191e2..9e98634f 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -112,9 +112,12 @@ def device_init(device_selection=None): thread_data.device = 'cpu' return if not torch.cuda.is_available(): - print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!') - thread_data.device = 'cpu' - return + if device_selection == 'auto' or device_selection == 'current': + print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!') + thread_data.device = 'cpu' + return + else: + raise EnvironmentError('torch.cuda is not available.') device_count = torch.cuda.device_count() if device_count <= 1 and device_selection == 'auto': device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.