diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 4bf12f9a..c5a36520 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -141,7 +141,7 @@ def update_modules(): else: sdkit_version = version_str_to_tuple(sdkit_version_str) legacy_sdkit_version = version_str_to_tuple(legacy_sdkit_version_str) - # torch_version = version_str_to_tuple(version("torch")) + torch_version = version_str_to_tuple(version("torch")) if sdkit_version[:3] <= legacy_sdkit_version[:3]: # and torch_version < (0, 13): # stick to diffusers 0.21.4, since it preserves torch 0.11+ compatibility. @@ -152,6 +152,12 @@ def update_modules(): install_pkg_if_necessary("sdkit", legacy_sdkit_version_str) install_pkg_if_necessary("diffusers", legacy_diffusers_version_str) else: + if torch_version < (0, 13): + # install the gpu-compatible torch (if necessary), instead of the default CPU-only one + # from the diffusers dependency chain + install("torch", modules_to_check["torch"][-1]) + install("torchvision", modules_to_check["torchvision"][-1]) + install_pkg_if_necessary("sdkit", expected_sdkit_version_str) install_pkg_if_necessary("diffusers", expected_diffusers_version_str)