diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 43219bba..990bc462 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -18,7 +18,6 @@ from pprint import pprint import re import torchruntime from torchruntime.device_db import get_gpus -from torchruntime.platform_detection import get_torch_platform os_name = platform.system() @@ -37,11 +36,13 @@ modules_to_check = { "onnxruntime": "1.19.2", "huggingface-hub": "0.21.4", "wandb": "0.17.2", - "torchruntime": "1.15.1", + # "torchruntime": "1.16.2", "torchsde": "0.2.6", } modules_to_log = ["torchruntime", "torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"] +BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b") + def version(module_name: str) -> str: try: @@ -71,9 +72,8 @@ def update_modules(): print(f"Current torch version: {torch_version} ({torch_version_str})") if torch_version < (2, 7): gpu_infos = get_gpus() - torch_platform = get_torch_platform(gpu_infos) - print(f"Recommended torch platform: {torch_platform}") - if torch_platform == "nightly/cu128": + device_names = set(gpu.device_name for gpu in gpu_infos) + if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names): print("Upgrading torch to support NVIDIA 50xx series of graphics cards") torchruntime.install(["--force", "--upgrade", "torch", "torchvision"]) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 957801b4..82167600 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -72,7 +72,7 @@ call where python call python --version @rem this is outside check_modules.py to ensure that the required version of torchruntime is present -call python -m pip install -q "torchruntime>=1.15.1" +call python -m pip install -q "torchruntime>=1.16.2" call python scripts\check_modules.py --launch-uvicorn pause diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 7d756b10..18d20471 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -51,7 +51,7 @@ if [ -e "src" ]; then mv src src-old; fi if [ -e "ldm" ]; then mv ldm ldm-old; fi # this is outside check_modules.py to ensure that the required version of torchruntime is present -python -m pip install -q "torchruntime>=1.15.1" +python -m pip install -q "torchruntime>=1.16.2" cd .. # Download the required packages