From 11265c40344f3c668d96a1d1e80abb5bb5e8ad14 Mon Sep 17 00:00:00 2001
From: Michael Gallacher
Date: Tue, 7 Mar 2023 14:57:37 -0700
Subject: [PATCH] Add support for MPS when running on Apple silicon

Changes:
* autodetect whether MPS is available and the PyTorch version has MPS support.
* change the device checks from "is the device CPU?" to "is the device not CUDA?".
* set PYTORCH_ENABLE_MPS_FALLBACK=1 so that ops without an MPS implementation fall back to the CPU.

Known issues:
* Some samplers (e.g. DDIM) will fail on MPS unless forced to CPU-only mode.
---
 scripts/on_sd_start.sh             |  1 +
 ui/easydiffusion/device_manager.py | 20 +++++++++++---------
 ui/easydiffusion/task_manager.py   |  2 +-
 3 files changed, 13 insertions(+), 10 deletions(-)
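Note: below is a minimal, standalone sketch of the device-selection order used by
auto_pick_devices() in the diff. `pick_render_device` is a hypothetical helper for
illustration only, not a function in the Easy Diffusion codebase; the
torch.backends.mps checks and the fallback variable are the ones this patch uses.

    # sketch.py - assumes PyTorch >= 1.12 (the first release with MPS support)
    import os

    # PYTORCH_ENABLE_MPS_FALLBACK=1 makes ops without an MPS kernel run on the
    # CPU instead of raising; on_sd_start.sh exports it before starting the server.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

    import platform
    import torch

    def pick_render_device() -> str:
        # Prefer Metal (MPS) on Apple silicon, then CUDA, then plain CPU.
        if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
            return "mps"
        if torch.cuda.is_available():
            return "cuda:0"
        return "cpu"

    if __name__ == "__main__":
        device = pick_render_device()
        print(device, torch.ones(3, device=device))  # tiny smoke test on the chosen device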
diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh
index a8ed26be..c547c9dd 100755
--- a/scripts/on_sd_start.sh
+++ b/scripts/on_sd_start.sh
@@ -285,6 +285,7 @@
 printf "\n\nEasy Diffusion installation complete, starting the server!\n\n"
 
 SD_PATH=`pwd`
 
+export PYTORCH_ENABLE_MPS_FALLBACK=1
 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
 echo "PYTHONPATH=$PYTHONPATH"
diff --git a/ui/easydiffusion/device_manager.py b/ui/easydiffusion/device_manager.py
index 4279cc46..8c486dd2 100644
--- a/ui/easydiffusion/device_manager.py
+++ b/ui/easydiffusion/device_manager.py
@@ -1,4 +1,5 @@
 import os
+import platform
 import torch
 import traceback
 import re
@@ -66,6 +67,9 @@ def get_device_delta(render_devices, active_devices):
 def auto_pick_devices(currently_active_devices):
     global mem_free_threshold
 
+    if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        return ["mps"]
+
     if not torch.cuda.is_available():
         return ["cpu"]
 
@@ -115,11 +119,11 @@ def device_init(context, device):
 
     validate_device_id(device, log_prefix="device_init")
 
-    if device == "cpu":
-        context.device = "cpu"
+    if "cuda" not in device:
+        context.device = device
         context.device_name = get_processor_name()
         context.half_precision = False
-        log.debug(f"Render device CPU available as {context.device_name}")
+        log.debug(f"Render device available as {context.device_name}")
         return
 
     context.device_name = torch.cuda.get_device_name(device)
@@ -134,8 +138,6 @@
     log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
     torch.cuda.device(device)
 
-    return
-
 
 def needs_to_force_full_precision(context):
     if "FORCE_FULL_PRECISION" in os.environ:
@@ -174,7 +176,7 @@ def validate_device_id(device, log_prefix=""):
     def is_valid():
         if not isinstance(device, str):
             return False
-        if device == "cpu":
+        if device == "cpu" or device == "mps":
             return True
         if not device.startswith("cuda:") or not device[5:].isnumeric():
             return False
@@ -182,7 +184,7 @@
 
     if not is_valid():
         raise EnvironmentError(
-            f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
+            f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
        )
 
 
@@ -217,14 +219,14 @@
 
 def get_processor_name():
     try:
-        import platform, subprocess
+        import subprocess
 
         if platform.system() == "Windows":
             return platform.processor()
         elif platform.system() == "Darwin":
             os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
             command = "sysctl -n machdep.cpu.brand_string"
-            return subprocess.check_output(command).strip()
+            return subprocess.check_output(command, shell=True).decode().strip()
         elif platform.system() == "Linux":
             command = "cat /proc/cpuinfo"
             all_info = subprocess.check_output(command, shell=True).decode().strip()
diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py
index 28f84963..d2e112ac 100644
--- a/ui/easydiffusion/task_manager.py
+++ b/ui/easydiffusion/task_manager.py
@@ -385,7 +385,7 @@ def get_devices():
     }
 
     def get_device_info(device):
-        if device == "cpu":
+        if "cuda" not in device:
             return {"name": device_manager.get_processor_name()}
 
         mem_free, mem_total = torch.cuda.mem_get_info(device)