Use pytorch 2.7 (with cuda 12.8) on new installations with NVIDIA gpus

This commit is contained in:
cmdr2 2025-04-24 15:50:55 +05:30
parent 0e3b6a8609
commit 1ed27e63f9
3 changed files with 7 additions and 7 deletions

View File

@ -18,7 +18,6 @@ from pprint import pprint
import re
import torchruntime
from torchruntime.device_db import get_gpus
from torchruntime.platform_detection import get_torch_platform
os_name = platform.system()
@ -37,11 +36,13 @@ modules_to_check = {
"onnxruntime": "1.19.2",
"huggingface-hub": "0.21.4",
"wandb": "0.17.2",
"torchruntime": "1.15.1",
# "torchruntime": "1.16.2",
"torchsde": "0.2.6",
}
modules_to_log = ["torchruntime", "torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b")
def version(module_name: str) -> str:
try:
@ -71,9 +72,8 @@ def update_modules():
print(f"Current torch version: {torch_version} ({torch_version_str})")
if torch_version < (2, 7):
gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
print(f"Recommended torch platform: {torch_platform}")
if torch_platform == "nightly/cu128":
device_names = set(gpu.device_name for gpu in gpu_infos)
if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names):
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])

View File

@ -72,7 +72,7 @@ call where python
call python --version
@rem this is outside check_modules.py to ensure that the required version of torchruntime is present
call python -m pip install -q "torchruntime>=1.15.1"
call python -m pip install -q "torchruntime>=1.16.2"
call python scripts\check_modules.py --launch-uvicorn
pause

View File

@ -51,7 +51,7 @@ if [ -e "src" ]; then mv src src-old; fi
if [ -e "ldm" ]; then mv ldm ldm-old; fi
# this is outside check_modules.py to ensure that the required version of torchruntime is present
python -m pip install -q "torchruntime>=1.15.1"
python -m pip install -q "torchruntime>=1.16.2"
cd ..
# Download the required packages