forked from extern/easydiffusion
Automatic AMD GPU detection on Linux (#1078)
* Automatic AMD GPU detection on Linux Automatically detects AMD GPUs and installs the ROCm version of PyTorch instead of the cuda one A later improvement may be to detect the GPU ROCm version and handle GPUs that dont work on upstream ROCm, ether because they're too old and need a special patched version, or too new and need `HSA_OVERRIDE_GFX_VERSION=10.3.0` added, possibly check through `rocminfo`? * Address stdout suppression and download failure * If any NVIDIA GPU is found, always use it * Use /proc/bus/pci/devices to detect GPUs * Fix comparisons `-eq` and `-ne` only work for numbers * Add back -q --------- Co-authored-by: JeLuF <jf@mormo.org>
This commit is contained in:
parent
127ee68486
commit
e7dc41e271
@ -63,11 +63,30 @@ case "${OS_NAME}" in
|
||||
*) echo "Unknown OS: $OS_NAME! This script runs only on Linux or Mac" && exit
|
||||
esac
|
||||
|
||||
# Detect GPU types
|
||||
|
||||
if grep -q amdgpu /proc/bus/pci/devices; then
|
||||
echo AMD GPU detected
|
||||
HAS_AMD=yes
|
||||
fi
|
||||
|
||||
if grep -q nvidia /proc/bus/pci/devices; then
|
||||
echo NVidia GPU detected
|
||||
HAS_NVIDIA=yes
|
||||
fi
|
||||
|
||||
|
||||
|
||||
# install torch and torchvision
|
||||
if python ../scripts/check_modules.py torch torchvision; then
|
||||
# temp fix for installations that installed torch 2.0 by mistake
|
||||
if [ "$OS_NAME" == "linux" ]; then
|
||||
python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -q
|
||||
# Check for AMD and NVIDIA dGPUs, always preferring an NVIDIA GPU if available
|
||||
if [ "$HAS_NVIDIA" != "yes" -a "$HAS_AMD" = "yes" ]; then
|
||||
python -m pip install --upgrade torch torchvision --extra-index-url "https://download.pytorch.org/whl/rocm5.4.2" -q || fail "Installation of torch and torchvision for AMD failed"
|
||||
else
|
||||
python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" -q || fail "Installation of torch and torchvision for CUDA failed"
|
||||
fi
|
||||
elif [ "$OS_NAME" == "macos" ]; then
|
||||
python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 -q
|
||||
fi
|
||||
@ -80,11 +99,13 @@ else
|
||||
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
||||
|
||||
if [ "$OS_NAME" == "linux" ]; then
|
||||
if python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 ; then
|
||||
echo "Installed."
|
||||
# Check for AMD and NVIDIA dGPUs, always preferring an NVIDIA GPU if available
|
||||
if [ "$HAS_NVIDIA" != "yes" -a "$HAS_AMD" = "yes" ]; then
|
||||
python -m pip install --upgrade torch torchvision --extra-index-url "https://download.pytorch.org/whl/rocm5.4.2" || fail "Installation of torch and torchvision for ROCm failed"
|
||||
else
|
||||
fail "torch install failed"
|
||||
python -m pip install --upgrade torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116" || fail "Installation of torch and torchvision for CUDA failed"
|
||||
fi
|
||||
echo "Installed."
|
||||
elif [ "$OS_NAME" == "macos" ]; then
|
||||
if python -m pip install --upgrade torch==1.13.1 torchvision==0.14.1 ; then
|
||||
echo "Installed."
|
||||
@ -151,7 +172,7 @@ else
|
||||
if conda install -c conda-forge -y uvicorn fastapi ; then
|
||||
echo "Installed. Testing.."
|
||||
else
|
||||
fail "'conda install uvicorn' failed"
|
||||
fail "'conda install uvicorn' failed"
|
||||
fi
|
||||
|
||||
if ! command -v uvicorn &> /dev/null; then
|
||||
@ -181,7 +202,7 @@ else
|
||||
if [ -f "../models/stable-diffusion/sd-v1-4.ckpt" ]; then
|
||||
model_size=`filesize "../models/stable-diffusion/sd-v1-4.ckpt"`
|
||||
if [ ! "$model_size" == "4265380512" ]; then
|
||||
fail "The downloaded model file was invalid! Bytes downloaded: $model_size"
|
||||
fail "The downloaded model file was invalid! Bytes downloaded: $model_size"
|
||||
fi
|
||||
else
|
||||
fail "Error downloading the data files (weights) for Stable Diffusion"
|
||||
|
Loading…
Reference in New Issue
Block a user