Extend the list of supported torch, CUDA and ROCm versions

This commit is contained in:
cmdr2 2025-01-06 19:30:18 +05:30
parent 4a62d4e76e
commit 493035df16

View File

@ -20,8 +20,41 @@ import re
os_name = platform.system()
modules_to_check = {
"torch": ("1.11.0", "1.13.1", "2.0.0", "2.0.1"),
"torchvision": ("0.12.0", "0.14.1", "0.15.1", "0.15.2"),
"torch": ( # really need a better way to check the versions. kinda reinventing the wheel here
"1.11.0",
"1.13.1",
"2.0.0",
"2.0.1",
"2.1.1",
"2.1.2",
"2.2.0",
"2.2.1",
"2.2.2",
"2.3.0",
"2.3.1",
"2.4.0",
"2.4.1",
"2.5.0",
"2.5.1",
),
"torchvision": (
"0.12.0",
"0.14.1",
"0.15.1",
"0.15.2",
"0.16.0",
"0.16.1",
"0.16.2",
"0.17.0",
"0.17.1",
"0.17.2",
"0.18.0",
"0.18.1",
"0.19.0",
"0.19.1",
"0.20.0",
"0.20.1",
),
"setuptools": "69.5.1",
# "sdkit": "2.0.15.6", # checked later
# "diffusers": "0.21.4", # checked later
@ -264,8 +297,24 @@ def include_cuda_versions(module_versions: tuple) -> tuple:
allowed_versions = tuple(module_versions)
allowed_versions += tuple(f"{v}+cu116" for v in module_versions)
allowed_versions += tuple(f"{v}+cu117" for v in module_versions)
allowed_versions += tuple(f"{v}+cu118" for v in module_versions)
allowed_versions += tuple(f"{v}+cu121" for v in module_versions)
allowed_versions += tuple(f"{v}+cu124" for v in module_versions)
allowed_versions += tuple(f"{v}+cu126" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm5.2" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm5.3" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm5.4.2" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm5.5" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm5.6" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm5.7" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm6.0" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm6.1" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm6.1" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm6.2" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm6.2.4" for v in module_versions)
allowed_versions += tuple(f"{v}+rocm6.3" for v in module_versions)
# needs a better way
return allowed_versions