mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-11 04:17:08 +02:00
Use pytorch 2.7 (with cuda 12.8) on new installations with NVIDIA gpus
This commit is contained in:
parent
0e3b6a8609
commit
1ed27e63f9
@ -18,7 +18,6 @@ from pprint import pprint
|
|||||||
import re
|
import re
|
||||||
import torchruntime
|
import torchruntime
|
||||||
from torchruntime.device_db import get_gpus
|
from torchruntime.device_db import get_gpus
|
||||||
from torchruntime.platform_detection import get_torch_platform
|
|
||||||
|
|
||||||
os_name = platform.system()
|
os_name = platform.system()
|
||||||
|
|
||||||
@ -37,11 +36,13 @@ modules_to_check = {
|
|||||||
"onnxruntime": "1.19.2",
|
"onnxruntime": "1.19.2",
|
||||||
"huggingface-hub": "0.21.4",
|
"huggingface-hub": "0.21.4",
|
||||||
"wandb": "0.17.2",
|
"wandb": "0.17.2",
|
||||||
"torchruntime": "1.15.1",
|
# "torchruntime": "1.16.2",
|
||||||
"torchsde": "0.2.6",
|
"torchsde": "0.2.6",
|
||||||
}
|
}
|
||||||
modules_to_log = ["torchruntime", "torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
|
modules_to_log = ["torchruntime", "torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
|
||||||
|
|
||||||
|
BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b")
|
||||||
|
|
||||||
|
|
||||||
def version(module_name: str) -> str:
|
def version(module_name: str) -> str:
|
||||||
try:
|
try:
|
||||||
@ -71,9 +72,8 @@ def update_modules():
|
|||||||
print(f"Current torch version: {torch_version} ({torch_version_str})")
|
print(f"Current torch version: {torch_version} ({torch_version_str})")
|
||||||
if torch_version < (2, 7):
|
if torch_version < (2, 7):
|
||||||
gpu_infos = get_gpus()
|
gpu_infos = get_gpus()
|
||||||
torch_platform = get_torch_platform(gpu_infos)
|
device_names = set(gpu.device_name for gpu in gpu_infos)
|
||||||
print(f"Recommended torch platform: {torch_platform}")
|
if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names):
|
||||||
if torch_platform == "nightly/cu128":
|
|
||||||
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
|
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
|
||||||
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])
|
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ call where python
|
|||||||
call python --version
|
call python --version
|
||||||
|
|
||||||
@rem this is outside check_modules.py to ensure that the required version of torchruntime is present
|
@rem this is outside check_modules.py to ensure that the required version of torchruntime is present
|
||||||
call python -m pip install -q "torchruntime>=1.15.1"
|
call python -m pip install -q "torchruntime>=1.16.2"
|
||||||
|
|
||||||
call python scripts\check_modules.py --launch-uvicorn
|
call python scripts\check_modules.py --launch-uvicorn
|
||||||
pause
|
pause
|
||||||
|
@ -51,7 +51,7 @@ if [ -e "src" ]; then mv src src-old; fi
|
|||||||
if [ -e "ldm" ]; then mv ldm ldm-old; fi
|
if [ -e "ldm" ]; then mv ldm ldm-old; fi
|
||||||
|
|
||||||
# this is outside check_modules.py to ensure that the required version of torchruntime is present
|
# this is outside check_modules.py to ensure that the required version of torchruntime is present
|
||||||
python -m pip install -q "torchruntime>=1.15.1"
|
python -m pip install -q "torchruntime>=1.16.2"
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
# Download the required packages
|
# Download the required packages
|
||||||
|
Loading…
x
Reference in New Issue
Block a user