mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-10 11:56:50 +02:00
Use pytorch 2.7 (with cuda 12.8) on new installations with NVIDIA gpus
This commit is contained in:
parent
0e3b6a8609
commit
1ed27e63f9
@ -18,7 +18,6 @@ from pprint import pprint
|
||||
import re
|
||||
import torchruntime
|
||||
from torchruntime.device_db import get_gpus
|
||||
from torchruntime.platform_detection import get_torch_platform
|
||||
|
||||
os_name = platform.system()
|
||||
|
||||
@ -37,11 +36,13 @@ modules_to_check = {
|
||||
"onnxruntime": "1.19.2",
|
||||
"huggingface-hub": "0.21.4",
|
||||
"wandb": "0.17.2",
|
||||
"torchruntime": "1.15.1",
|
||||
# "torchruntime": "1.16.2",
|
||||
"torchsde": "0.2.6",
|
||||
}
|
||||
modules_to_log = ["torchruntime", "torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
|
||||
|
||||
BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b")
|
||||
|
||||
|
||||
def version(module_name: str) -> str:
|
||||
try:
|
||||
@ -71,9 +72,8 @@ def update_modules():
|
||||
print(f"Current torch version: {torch_version} ({torch_version_str})")
|
||||
if torch_version < (2, 7):
|
||||
gpu_infos = get_gpus()
|
||||
torch_platform = get_torch_platform(gpu_infos)
|
||||
print(f"Recommended torch platform: {torch_platform}")
|
||||
if torch_platform == "nightly/cu128":
|
||||
device_names = set(gpu.device_name for gpu in gpu_infos)
|
||||
if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names):
|
||||
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
|
||||
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])
|
||||
|
||||
|
@ -72,7 +72,7 @@ call where python
|
||||
call python --version
|
||||
|
||||
@rem this is outside check_modules.py to ensure that the required version of torchruntime is present
|
||||
call python -m pip install -q "torchruntime>=1.15.1"
|
||||
call python -m pip install -q "torchruntime>=1.16.2"
|
||||
|
||||
call python scripts\check_modules.py --launch-uvicorn
|
||||
pause
|
||||
|
@ -51,7 +51,7 @@ if [ -e "src" ]; then mv src src-old; fi
|
||||
if [ -e "ldm" ]; then mv ldm ldm-old; fi
|
||||
|
||||
# this is outside check_modules.py to ensure that the required version of torchruntime is present
|
||||
python -m pip install -q "torchruntime>=1.15.1"
|
||||
python -m pip install -q "torchruntime>=1.16.2"
|
||||
|
||||
cd ..
|
||||
# Download the required packages
|
||||
|
Loading…
x
Reference in New Issue
Block a user