Upgrade the version of torch used for rocm for Navi 30+, and point to the broader torch URL

This commit is contained in:
cmdr2 2025-01-07 10:32:02 +05:30
parent 52aaef5e39
commit 5023619676

View File

@ -94,15 +94,15 @@ def install(module_name: str, module_version: str):
if module_name == "torch":
if "Navi 3" in amd_gpus:
# No AMD 7x00 support in rocm 5.2, needs 5.5+
module_version = "2.1.0+rocm5.5"
index_url = "https://download.pytorch.org/whl/rocm6.2"
module_version = "2.4.1+rocm6.1"
index_url = "https://download.pytorch.org/whl"
else:
module_version = "1.13.1+rocm5.2"
elif module_name == "torchvision":
if "Navi 3" in amd_gpus:
# No AMD 7x00 support in rocm 5.2, needs 5.5+
module_version = "0.16.0+rocm5.5"
index_url = "https://download.pytorch.org/whl/rocm6.2"
module_version = "0.20.0+rocm6.1"
index_url = "https://download.pytorch.org/whl"
else:
module_version = "0.14.1+rocm5.2"
elif os_name == "Darwin":
@ -286,7 +286,7 @@ def apply_torch_install_overrides(module_version: str):
module_version += "+cu117"
index_url = "https://download.pytorch.org/whl/cu117"
elif is_amd_on_linux():
index_url = "https://download.pytorch.org/whl/rocm6.2"
index_url = "https://download.pytorch.org/whl"
return module_version, index_url