forked from extern/easydiffusion
Merge pull request #975 from michaelgallacher/beta
Add support for MPS when running on Apple silicon
This commit is contained in:
commit
737a81570a
@ -285,6 +285,7 @@ printf "\n\nEasy Diffusion installation complete, starting the server!\n\n"
|
||||
|
||||
SD_PATH=`pwd`
|
||||
|
||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
||||
echo "PYTHONPATH=$PYTHONPATH"
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import platform
|
||||
import torch
|
||||
import traceback
|
||||
import re
|
||||
@ -66,6 +67,9 @@ def get_device_delta(render_devices, active_devices):
|
||||
def auto_pick_devices(currently_active_devices):
|
||||
global mem_free_threshold
|
||||
|
||||
if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||
return ["mps"]
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return ["cpu"]
|
||||
|
||||
@ -115,11 +119,11 @@ def device_init(context, device):
|
||||
|
||||
validate_device_id(device, log_prefix="device_init")
|
||||
|
||||
if device == "cpu":
|
||||
context.device = "cpu"
|
||||
if "cuda" not in device:
|
||||
context.device = device
|
||||
context.device_name = get_processor_name()
|
||||
context.half_precision = False
|
||||
log.debug(f"Render device CPU available as {context.device_name}")
|
||||
log.debug(f"Render device available as {context.device_name}")
|
||||
return
|
||||
|
||||
context.device_name = torch.cuda.get_device_name(device)
|
||||
@ -134,8 +138,6 @@ def device_init(context, device):
|
||||
log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
|
||||
torch.cuda.device(device)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def needs_to_force_full_precision(context):
|
||||
if "FORCE_FULL_PRECISION" in os.environ:
|
||||
@ -174,7 +176,7 @@ def validate_device_id(device, log_prefix=""):
|
||||
def is_valid():
|
||||
if not isinstance(device, str):
|
||||
return False
|
||||
if device == "cpu":
|
||||
if device == "cpu" or device == "mps":
|
||||
return True
|
||||
if not device.startswith("cuda:") or not device[5:].isnumeric():
|
||||
return False
|
||||
@ -182,7 +184,7 @@ def validate_device_id(device, log_prefix=""):
|
||||
|
||||
if not is_valid():
|
||||
raise EnvironmentError(
|
||||
f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
|
||||
f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
|
||||
)
|
||||
|
||||
|
||||
@ -217,14 +219,14 @@ def is_device_compatible(device):
|
||||
|
||||
def get_processor_name():
|
||||
try:
|
||||
import platform, subprocess
|
||||
import subprocess
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return platform.processor()
|
||||
elif platform.system() == "Darwin":
|
||||
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
|
||||
command = "sysctl -n machdep.cpu.brand_string"
|
||||
return subprocess.check_output(command).strip()
|
||||
return subprocess.check_output(command, shell=True).decode().strip()
|
||||
elif platform.system() == "Linux":
|
||||
command = "cat /proc/cpuinfo"
|
||||
all_info = subprocess.check_output(command, shell=True).decode().strip()
|
||||
|
@ -385,7 +385,7 @@ def get_devices():
|
||||
}
|
||||
|
||||
def get_device_info(device):
|
||||
if device == "cpu":
|
||||
if "cuda" not in device:
|
||||
return {"name": device_manager.get_processor_name()}
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
|
Loading…
Reference in New Issue
Block a user