mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-19 17:39:16 +02:00
Auto-upgrade torch for NVIDIA 50xx series. Fix for #1918
This commit is contained in:
parent
bac23290dd
commit
a33737b991
@ -17,6 +17,8 @@ from pathlib import Path
|
||||
from pprint import pprint
|
||||
import re
|
||||
import torchruntime
|
||||
from torchruntime.device_db import get_gpus
|
||||
from torchruntime.platform_detection import get_torch_platform
|
||||
|
||||
os_name = platform.system()
|
||||
|
||||
@ -63,6 +65,17 @@ def install(module_name: str, module_version: str, index_url=None):
|
||||
def update_modules():
|
||||
if version("torch") is None:
|
||||
torchruntime.install(["torch", "torchvision"])
|
||||
else:
|
||||
torch_version_str = version("torch")
|
||||
torch_version = version_str_to_tuple(torch_version_str)
|
||||
print(f"Current torch version: {torch_version} ({torch_version_str})")
|
||||
if torch_version < (2, 7):
|
||||
gpu_infos = get_gpus()
|
||||
torch_platform = get_torch_platform(gpu_infos)
|
||||
print(f"Recommended torch platform: {torch_platform}")
|
||||
if torch_platform == "nightly/cu128":
|
||||
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
|
||||
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])
|
||||
|
||||
for module_name, allowed_versions in modules_to_check.items():
|
||||
if os.path.exists(f"src/{module_name}"):
|
||||
|
Loading…
x
Reference in New Issue
Block a user