From a33737b991a4fa060ed2a873f1ab127bc1a1dcb9 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 23 Apr 2025 15:22:59 +0530 Subject: [PATCH] Auto-upgrade torch for NVIDIA 50xx series. Fix for #1918 --- scripts/check_modules.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 448ea4ad..035b1e54 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -17,6 +17,8 @@ from pathlib import Path from pprint import pprint import re import torchruntime +from torchruntime.device_db import get_gpus +from torchruntime.platform_detection import get_torch_platform os_name = platform.system() @@ -63,6 +65,17 @@ def install(module_name: str, module_version: str, index_url=None): def update_modules(): if version("torch") is None: torchruntime.install(["torch", "torchvision"]) + else: + torch_version_str = version("torch") + torch_version = version_str_to_tuple(torch_version_str) + print(f"Current torch version: {torch_version} ({torch_version_str})") + if torch_version < (2, 7): + gpu_infos = get_gpus() + torch_platform = get_torch_platform(gpu_infos) + print(f"Recommended torch platform: {torch_platform}") + if torch_platform == "nightly/cu128": + print("Upgrading torch to support NVIDIA 50xx series of graphics cards") + torchruntime.install(["--force", "--upgrade", "torch", "torchvision"]) for module_name, allowed_versions in modules_to_check.items(): if os.path.exists(f"src/{module_name}"):