Restrict AMD cards on Linux to torch 1.13.1 and ROCm 5.2. Avoids black images on some AMD cards. Temp hack until AMD works properly on torch 2.0

This commit is contained in:
cmdr2 2023-04-21 19:08:51 +05:30
parent 1967299417
commit eb16296873

View File

@ -42,6 +42,12 @@ def install(module_name: str, module_version: str):
if module_name in ("torch", "torchvision"): if module_name in ("torch", "torchvision"):
module_version, index_url = apply_torch_install_overrides(module_version) module_version, index_url = apply_torch_install_overrides(module_version)
if is_amd_on_linux(): # hack until AMD works properly on torch 2.0 (avoids black images on some cards)
if module_name == "torch":
module_version = "1.13.1+rocm5.2"
elif module_name == "torchvision":
module_version = "0.14.1+rocm5.2"
install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}" install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}"
if index_url: if index_url:
install_cmd += f" --index-url {index_url}" install_cmd += f" --index-url {index_url}"
@ -96,7 +102,7 @@ def apply_torch_install_overrides(module_version: str):
module_version += "+cu117" module_version += "+cu117"
index_url = "https://download.pytorch.org/whl/cu117" index_url = "https://download.pytorch.org/whl/cu117"
elif is_amd_on_linux(): elif is_amd_on_linux():
index_url = "https://download.pytorch.org/whl/rocm5.4.2" index_url = "https://download.pytorch.org/whl/rocm5.2"
return module_version, index_url return module_version, index_url