Merge pull request #975 from michaelgallacher/beta

Add support for MPS when running on Apple silicon
This commit is contained in:
cmdr2 2023-03-08 10:00:10 +05:30 committed by GitHub
commit 737a81570a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 10 deletions

View File

@ -285,6 +285,7 @@ printf "\n\nEasy Diffusion installation complete, starting the server!\n\n"
SD_PATH=`pwd`
export PYTORCH_ENABLE_MPS_FALLBACK=1
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
echo "PYTHONPATH=$PYTHONPATH"

View File

@ -1,4 +1,5 @@
import os
import platform
import torch
import traceback
import re
@ -66,6 +67,9 @@ def get_device_delta(render_devices, active_devices):
def auto_pick_devices(currently_active_devices):
global mem_free_threshold
if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
return ["mps"]
if not torch.cuda.is_available():
return ["cpu"]
@ -115,11 +119,11 @@ def device_init(context, device):
validate_device_id(device, log_prefix="device_init")
if device == "cpu":
context.device = "cpu"
if "cuda" not in device:
context.device = device
context.device_name = get_processor_name()
context.half_precision = False
log.debug(f"Render device CPU available as {context.device_name}")
log.debug(f"Render device available as {context.device_name}")
return
context.device_name = torch.cuda.get_device_name(device)
@ -134,8 +138,6 @@ def device_init(context, device):
log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
torch.cuda.device(device)
return
def needs_to_force_full_precision(context):
if "FORCE_FULL_PRECISION" in os.environ:
@ -174,7 +176,7 @@ def validate_device_id(device, log_prefix=""):
def is_valid():
if not isinstance(device, str):
return False
if device == "cpu":
if device == "cpu" or device == "mps":
return True
if not device.startswith("cuda:") or not device[5:].isnumeric():
return False
@ -182,7 +184,7 @@ def validate_device_id(device, log_prefix=""):
if not is_valid():
raise EnvironmentError(
f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
)
@ -217,14 +219,14 @@ def is_device_compatible(device):
def get_processor_name():
try:
import platform, subprocess
import subprocess
if platform.system() == "Windows":
return platform.processor()
elif platform.system() == "Darwin":
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
command = "sysctl -n machdep.cpu.brand_string"
return subprocess.check_output(command).strip()
return subprocess.check_output(command, shell=True).decode().strip()
elif platform.system() == "Linux":
command = "cat /proc/cpuinfo"
all_info = subprocess.check_output(command, shell=True).decode().strip()

View File

@ -385,7 +385,7 @@ def get_devices():
}
def get_device_info(device):
if device == "cpu":
if "cuda" not in device:
return {"name": device_manager.get_processor_name()}
mem_free, mem_total = torch.cuda.mem_get_info(device)