mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-20 01:48:21 +02:00
Auto-upgrade torch for NVIDIA 50xx series. Fix for #1918
This commit is contained in:
parent
bac23290dd
commit
a33737b991
@ -17,6 +17,8 @@ from pathlib import Path
|
|||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
import re
|
import re
|
||||||
import torchruntime
|
import torchruntime
|
||||||
|
from torchruntime.device_db import get_gpus
|
||||||
|
from torchruntime.platform_detection import get_torch_platform
|
||||||
|
|
||||||
os_name = platform.system()
|
os_name = platform.system()
|
||||||
|
|
||||||
@ -63,6 +65,17 @@ def install(module_name: str, module_version: str, index_url=None):
|
|||||||
def update_modules():
|
def update_modules():
|
||||||
if version("torch") is None:
|
if version("torch") is None:
|
||||||
torchruntime.install(["torch", "torchvision"])
|
torchruntime.install(["torch", "torchvision"])
|
||||||
|
else:
|
||||||
|
torch_version_str = version("torch")
|
||||||
|
torch_version = version_str_to_tuple(torch_version_str)
|
||||||
|
print(f"Current torch version: {torch_version} ({torch_version_str})")
|
||||||
|
if torch_version < (2, 7):
|
||||||
|
gpu_infos = get_gpus()
|
||||||
|
torch_platform = get_torch_platform(gpu_infos)
|
||||||
|
print(f"Recommended torch platform: {torch_platform}")
|
||||||
|
if torch_platform == "nightly/cu128":
|
||||||
|
print("Upgrading torch to support NVIDIA 50xx series of graphics cards")
|
||||||
|
torchruntime.install(["--force", "--upgrade", "torch", "torchvision"])
|
||||||
|
|
||||||
for module_name, allowed_versions in modules_to_check.items():
|
for module_name, allowed_versions in modules_to_check.items():
|
||||||
if os.path.exists(f"src/{module_name}"):
|
if os.path.exists(f"src/{module_name}"):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user