forked from extern/easydiffusion
Automatic AMD GPU detection on Linux (#1078)
* Automatic AMD GPU detection on Linux Automatically detects AMD GPUs and installs the ROCm version of PyTorch instead of the cuda one A later improvement may be to detect the GPU ROCm version and handle GPUs that dont work on upstream ROCm, ether because they're too old and need a special patched version, or too new and need `HSA_OVERRIDE_GFX_VERSION=10.3.0` added, possibly check through `rocminfo`? * Address stdout suppression and download failure * If any NVIDIA GPU is found, always use it * Use /proc/bus/pci/devices to detect GPUs * Fix comparisons `-eq` and `-ne` only work for numbers * Add back -q --------- Co-authored-by: JeLuF <jf@mormo.org>
This commit is contained in:
parent
127ee68486
commit
e7dc41e271
@ -63,11 +63,30 @@ case "${OS_NAME}" in
|
|||||||
*) echo "Unknown OS: $OS_NAME! This script runs only on Linux or Mac" && exit
|
*) echo "Unknown OS: $OS_NAME! This script runs only on Linux or Mac" && exit
|
||||||
esac
|
esac
|
||||||
|
|
||||||
|
# Detect GPU types
|
||||||
|
|
||||||
|
if grep -q amdgpu /proc/bus/pci/devices; then
|
||||||
|
echo AMD GPU detected
|
||||||
|
HAS_AMD=yes
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q nvidia /proc/bus/pci/devices; then
|
||||||
|
echo NVidia GPU detected
|
||||||
|
HAS_NVIDIA=yes
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# install torch and torchvision
|
# install torch and torchvision
|
||||||
if python ../scripts/check_modules.py torch torchvision; then
|
if python ../scripts/check_modules.py torch torchvision; then
|
||||||
# temp fix for installations that installed torch 2.0 by mistake
|
# temp fix for installations that installed torch 2.0 by mistake
|
||||||
if [ "$OS_NAME" == "linux" ]; then
|
if [ "$OS_NAME" == "linux" ]; then
|
||||||
python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -q
|
# Check for AMD and NVIDIA dGPUs, always preferring an NVIDIA GPU if available
|
||||||
|
if [ "$HAS_NVIDIA" != "yes" -a "$HAS_AMD" = "yes" ]; then
|
||||||
|
python -m pip install --upgrade torch torchvision --extra-index-url "https://download.pytorch.org/whl/rocm5.4.2" -q || fail "Installation of torch and torchvision for AMD failed"
|
||||||
|
else
|
||||||
|
python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" -q || fail "Installation of torch and torchvision for CUDA failed"
|
||||||
|
fi
|
||||||
elif [ "$OS_NAME" == "macos" ]; then
|
elif [ "$OS_NAME" == "macos" ]; then
|
||||||
python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 -q
|
python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 -q
|
||||||
fi
|
fi
|
||||||
@ -80,11 +99,13 @@ else
|
|||||||
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
||||||
|
|
||||||
if [ "$OS_NAME" == "linux" ]; then
|
if [ "$OS_NAME" == "linux" ]; then
|
||||||
if python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 ; then
|
# Check for AMD and NVIDIA dGPUs, always preferring an NVIDIA GPU if available
|
||||||
echo "Installed."
|
if [ "$HAS_NVIDIA" != "yes" -a "$HAS_AMD" = "yes" ]; then
|
||||||
|
python -m pip install --upgrade torch torchvision --extra-index-url "https://download.pytorch.org/whl/rocm5.4.2" || fail "Installation of torch and torchvision for ROCm failed"
|
||||||
else
|
else
|
||||||
fail "torch install failed"
|
python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" || fail "Installation of torch and torchvision for CUDA failed"
|
||||||
fi
|
fi
|
||||||
|
echo "Installed."
|
||||||
elif [ "$OS_NAME" == "macos" ]; then
|
elif [ "$OS_NAME" == "macos" ]; then
|
||||||
if python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 ; then
|
if python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 ; then
|
||||||
echo "Installed."
|
echo "Installed."
|
||||||
|
Loading…
Reference in New Issue
Block a user