From eb16296873baa3b1b10ae4f88c3c668d411a852f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 21 Apr 2023 19:08:51 +0530 Subject: [PATCH] Restrict AMD cards on Linux to torch 1.13.1 and ROCm 5.2. Avoids black images on some AMD cards. Temp hack until AMD works properly on torch 2.0 --- scripts/check_modules.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 775724b8..8ef43b09 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -42,6 +42,12 @@ def install(module_name: str, module_version: str): if module_name in ("torch", "torchvision"): module_version, index_url = apply_torch_install_overrides(module_version) + if is_amd_on_linux(): # hack until AMD works properly on torch 2.0 (avoids black images on some cards) + if module_name == "torch": + module_version = "1.13.1+rocm5.2" + elif module_name == "torchvision": + module_version = "0.14.1+rocm5.2" + install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}" if index_url: install_cmd += f" --index-url {index_url}" @@ -96,7 +102,7 @@ def apply_torch_install_overrides(module_version: str): module_version += "+cu117" index_url = "https://download.pytorch.org/whl/cu117" elif is_amd_on_linux(): - index_url = "https://download.pytorch.org/whl/rocm5.4.2" + index_url = "https://download.pytorch.org/whl/rocm5.2" return module_version, index_url