Add support for MPS when running on Apple silicon

Changes:

* autodetect whether MPS is available and the installed PyTorch build has MPS support (a hedged sketch follows this list).
* change the device logic from "is the device CPU?" to "is the device not CUDA?", so MPS takes the same code path as CPU.
* set PYTORCH_ENABLE_MPS_FALLBACK=1 in the start script, so operators without an MPS kernel fall back to the CPU.
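
A minimal sketch of the detection order described above. The name pick_render_device is illustrative (the commit's own logic lives in auto_pick_devices() in the diff below), and the getattr guard stands in for "does this PyTorch build ship torch.backends.mps at all":

```python
# Illustrative sketch only; the commit's real logic is auto_pick_devices() below.
import platform
import torch

def pick_render_device() -> str:
    # Older PyTorch builds do not ship torch.backends.mps at all, hence the getattr.
    mps = getattr(torch.backends, "mps", None)
    if platform.system() == "Darwin" and mps is not None and mps.is_built() and mps.is_available():
        return "mps"   # Apple silicon / Metal path
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
```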

Known issues:

* Some samplers (e.g. DDIM) will fail on MPS unless forced to CPU-only mode (a hypothetical workaround is sketched below).
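
The workaround below is not part of this commit; it is a generic, hypothetical sketch of the "force this step onto the CPU" idea, with illustrative names (run_step_with_cpu_fallback, step_fn):

```python
# Hypothetical sketch, not from this commit: retry a failing sampler step on the
# CPU and move the result back to the original device.
import torch

def run_step_with_cpu_fallback(step_fn, latents: torch.Tensor) -> torch.Tensor:
    try:
        return step_fn(latents)
    except (NotImplementedError, RuntimeError):
        # Some operators have no MPS kernel; run this step on the CPU instead.
        result = step_fn(latents.to("cpu"))
        return result.to(latents.device)
```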
Michael Gallacher, 2023-03-07 14:57:37 -07:00
commit 11265c4034 (parent 8acff43028)
3 changed files with 13 additions and 10 deletions

First changed file (the shell start script):

@@ -285,6 +285,7 @@ printf "\n\nEasy Diffusion installation complete, starting the server!\n\n"
 SD_PATH=`pwd`
+export PYTORCH_ENABLE_MPS_FALLBACK=1
 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
 echo "PYTHONPATH=$PYTHONPATH"

Second changed file (the device manager module):

@@ -1,4 +1,5 @@
 import os
+import platform
 import torch
 import traceback
 import re
@@ -66,6 +67,9 @@ def get_device_delta(render_devices, active_devices):
 def auto_pick_devices(currently_active_devices):
     global mem_free_threshold
 
+    if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        return ["mps"]
+
     if not torch.cuda.is_available():
         return ["cpu"]
@@ -115,11 +119,11 @@ def device_init(context, device):
     validate_device_id(device, log_prefix="device_init")
 
-    if device == "cpu":
-        context.device = "cpu"
+    if "cuda" not in device:
+        context.device = device
         context.device_name = get_processor_name()
         context.half_precision = False
-        log.debug(f"Render device CPU available as {context.device_name}")
+        log.debug(f"Render device available as {context.device_name}")
         return
 
     context.device_name = torch.cuda.get_device_name(device)
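
A hedged restatement of the branch above: any non-CUDA device string ("cpu" or "mps") is handed straight to torch and kept at full precision, while only CUDA devices go through the torch.cuda queries. pick_precision is an illustrative helper, not part of the commit:

```python
# Illustrative helper mirroring the branch above: non-CUDA devices ("cpu", "mps")
# stay at full precision, CUDA may use half precision.
import torch

def pick_precision(device: str) -> torch.dtype:
    if "cuda" not in device:
        return torch.float32
    return torch.float16
```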
@@ -134,8 +138,6 @@ def device_init(context, device):
     log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
     torch.cuda.device(device)
 
-    return
-
 def needs_to_force_full_precision(context):
     if "FORCE_FULL_PRECISION" in os.environ:
@@ -174,7 +176,7 @@ def validate_device_id(device, log_prefix=""):
     def is_valid():
         if not isinstance(device, str):
             return False
-        if device == "cpu":
+        if device == "cpu" or device == "mps":
             return True
         if not device.startswith("cuda:") or not device[5:].isnumeric():
             return False
@@ -182,7 +184,7 @@ def validate_device_id(device, log_prefix=""):
     if not is_valid():
         raise EnvironmentError(
-            f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
+            f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
         )
@@ -217,14 +219,14 @@ def is_device_compatible(device):
 def get_processor_name():
     try:
-        import platform, subprocess
+        import subprocess
 
         if platform.system() == "Windows":
             return platform.processor()
         elif platform.system() == "Darwin":
             os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
             command = "sysctl -n machdep.cpu.brand_string"
-            return subprocess.check_output(command).strip()
+            return subprocess.check_output(command, shell=True).decode().strip()
         elif platform.system() == "Linux":
             command = "cat /proc/cpuinfo"
             all_info = subprocess.check_output(command, shell=True).decode().strip()
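
The added .decode() is needed because subprocess.check_output returns bytes. An equivalent macOS lookup without shell=True (argument-list form) would look roughly like this, assuming sysctl is on the PATH:

```python
# Equivalent lookup without shell=True: pass the argument list directly.
# check_output returns bytes, hence the decode(); assumes macOS with sysctl on PATH.
import subprocess

def apple_cpu_brand() -> str:
    return subprocess.check_output(
        ["sysctl", "-n", "machdep.cpu.brand_string"]
    ).decode().strip()
```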

Third changed file:

@@ -385,7 +385,7 @@ def get_devices():
     }
 
 def get_device_info(device):
-    if device == "cpu":
+    if "cuda" not in device:
        return {"name": device_manager.get_processor_name()}
 
     mem_free, mem_total = torch.cuda.mem_get_info(device)
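
This follows the same "not CUDA" convention as device_manager: torch.cuda.mem_get_info() only exists for CUDA devices, so "cpu" and "mps" report a name rather than a VRAM split. A minimal, illustrative version of that branch:

```python
# Illustrative only: non-CUDA devices ("cpu", "mps") have no torch.cuda memory
# counters, so they report a name instead of a free/total VRAM split.
import torch

def describe_device(device: str) -> dict:
    if "cuda" not in device:
        return {"device": device}
    mem_free, mem_total = torch.cuda.mem_get_info(device)
    return {"device": device, "mem_free": mem_free, "mem_total": mem_total}
```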