From 05f0bfebba9353d5bb260b210a0b69d903b3811b Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 6 Jun 2024 16:18:11 +0530 Subject: [PATCH] Upgrade torch if using the newer sdkit versions --- scripts/check_modules.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 4bf12f9a..c5a36520 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -141,7 +141,7 @@ def update_modules(): else: sdkit_version = version_str_to_tuple(sdkit_version_str) legacy_sdkit_version = version_str_to_tuple(legacy_sdkit_version_str) - # torch_version = version_str_to_tuple(version("torch")) + torch_version = version_str_to_tuple(version("torch")) if sdkit_version[:3] <= legacy_sdkit_version[:3]: # and torch_version < (0, 13): # stick to diffusers 0.21.4, since it preserves torch 0.11+ compatibility. @@ -152,6 +152,12 @@ def update_modules(): install_pkg_if_necessary("sdkit", legacy_sdkit_version_str) install_pkg_if_necessary("diffusers", legacy_diffusers_version_str) else: + if torch_version < (0, 13): + # install the gpu-compatible torch (if necessary), instead of the default CPU-only one + # from the diffusers dependency chain + install("torch", modules_to_check["torch"][-1]) + install("torchvision", modules_to_check["torchvision"][-1]) + install_pkg_if_necessary("sdkit", expected_sdkit_version_str) install_pkg_if_necessary("diffusers", expected_diffusers_version_str)