From e7dc41e2715c28f41928530cf6f9bb7ec91408f7 Mon Sep 17 00:00:00 2001 From: Diana <5275194+DianaNites@users.noreply.github.com> Date: Tue, 18 Apr 2023 02:32:39 -0700 Subject: [PATCH] Automatic AMD GPU detection on Linux (#1078) * Automatic AMD GPU detection on Linux Automatically detects AMD GPUs and installs the ROCm version of PyTorch instead of the CUDA one. A later improvement may be to detect the GPU ROCm version and handle GPUs that don't work on upstream ROCm, either because they're too old and need a special patched version, or too new and need `HSA_OVERRIDE_GFX_VERSION=10.3.0` added, possibly check through `rocminfo`? * Address stdout suppression and download failure * If any NVIDIA GPU is found, always use it * Use /proc/bus/pci/devices to detect GPUs * Fix comparisons: `-eq` and `-ne` only work for numbers * Add back -q --------- Co-authored-by: JeLuF --- scripts/on_sd_start.sh | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 53f2c3b0..bc4b1608 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -63,11 +63,30 @@ case "${OS_NAME}" in *) echo "Unknown OS: $OS_NAME! 
This script runs only on Linux or Mac" && exit esac +# Detect GPU types + +if grep -q amdgpu /proc/bus/pci/devices; then + echo AMD GPU detected + HAS_AMD=yes +fi + +if grep -q nvidia /proc/bus/pci/devices; then + echo NVidia GPU detected + HAS_NVIDIA=yes +fi + + + # install torch and torchvision if python ../scripts/check_modules.py torch torchvision; then # temp fix for installations that installed torch 2.0 by mistake if [ "$OS_NAME" == "linux" ]; then - python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -q + # Check for AMD and NVIDIA dGPUs, always preferring an NVIDIA GPU if available + if [ "$HAS_NVIDIA" != "yes" -a "$HAS_AMD" = "yes" ]; then + python -m pip install --upgrade torch torchvision --extra-index-url "https://download.pytorch.org/whl/rocm5.4.2" -q || fail "Installation of torch and torchvision for AMD failed" + else + python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" -q || fail "Installation of torch and torchvision for CUDA failed" + fi elif [ "$OS_NAME" == "macos" ]; then python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 -q fi @@ -80,11 +99,13 @@ else export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" if [ "$OS_NAME" == "linux" ]; then - if python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 ; then - echo "Installed." 
+ # Check for AMD and NVIDIA dGPUs, always preferring an NVIDIA GPU if available + if [ "$HAS_NVIDIA" != "yes" -a "$HAS_AMD" = "yes" ]; then + python -m pip install --upgrade torch torchvision --extra-index-url "https://download.pytorch.org/whl/rocm5.4.2" || fail "Installation of torch and torchvision for ROCm failed" else - fail "torch install failed" + python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" || fail "Installation of torch and torchvision for CUDA failed" fi + echo "Installed." elif [ "$OS_NAME" == "macos" ]; then if python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 ; then echo "Installed." @@ -151,7 +172,7 @@ else if conda install -c conda-forge -y uvicorn fastapi ; then echo "Installed. Testing.." else - fail "'conda install uvicorn' failed" + fail "'conda install uvicorn' failed" fi if ! command -v uvicorn &> /dev/null; then @@ -181,7 +202,7 @@ else if [ -f "../models/stable-diffusion/sd-v1-4.ckpt" ]; then model_size=`filesize "../models/stable-diffusion/sd-v1-4.ckpt"` if [ ! "$model_size" == "4265380512" ]; then - fail "The downloaded model file was invalid! Bytes downloaded: $model_size" + fail "The downloaded model file was invalid! Bytes downloaded: $model_size" fi else fail "Error downloading the data files (weights) for Stable Diffusion"