Auto-upgrade torch for NVIDIA 50xx series. Fix for #1918

This commit is contained in:
cmdr2 2025-04-23 15:22:59 +05:30
parent bac23290dd
commit a33737b991

View File

@ -17,6 +17,8 @@ from pathlib import Path
from pprint import pprint
import re
import torchruntime
from torchruntime.device_db import get_gpus
from torchruntime.platform_detection import get_torch_platform
os_name = platform.system()
@ -63,6 +65,17 @@ def install(module_name: str, module_version: str, index_url=None):
def update_modules():
if version("torch") is None:
torchruntime.install(["torch", "torchvision"])
else:
torch_version_str = version("torch")
torch_version = version_str_to_tuple(torch_version_str)
print(f"Current torch version: {torch_version} ({torch_version_str})")
if torch_version < (2, 7):
gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
print(f"Recommended torch platform: {torch_platform}")
if torch_platform == "nightly/cu128":
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])
for module_name, allowed_versions in modules_to_check.items():
if os.path.exists(f"src/{module_name}"):