Upgrade torch if using the newer sdkit versions

This commit is contained in:
cmdr2 2024-06-06 16:18:11 +05:30
parent de680dfd09
commit 05f0bfebba

View File

@ -141,7 +141,7 @@ def update_modules():
else: else:
sdkit_version = version_str_to_tuple(sdkit_version_str) sdkit_version = version_str_to_tuple(sdkit_version_str)
legacy_sdkit_version = version_str_to_tuple(legacy_sdkit_version_str) legacy_sdkit_version = version_str_to_tuple(legacy_sdkit_version_str)
# torch_version = version_str_to_tuple(version("torch")) torch_version = version_str_to_tuple(version("torch"))
if sdkit_version[:3] <= legacy_sdkit_version[:3]: # and torch_version < (0, 13): if sdkit_version[:3] <= legacy_sdkit_version[:3]: # and torch_version < (0, 13):
# stick to diffusers 0.21.4, since it preserves torch 0.11+ compatibility. # stick to diffusers 0.21.4, since it preserves torch 0.11+ compatibility.
@ -152,6 +152,12 @@ def update_modules():
install_pkg_if_necessary("sdkit", legacy_sdkit_version_str) install_pkg_if_necessary("sdkit", legacy_sdkit_version_str)
install_pkg_if_necessary("diffusers", legacy_diffusers_version_str) install_pkg_if_necessary("diffusers", legacy_diffusers_version_str)
else: else:
if torch_version < (0, 13):
# install the gpu-compatible torch (if necessary), instead of the default CPU-only one
# from the diffusers dependency chain
install("torch", modules_to_check["torch"][-1])
install("torchvision", modules_to_check["torchvision"][-1])
install_pkg_if_necessary("sdkit", expected_sdkit_version_str) install_pkg_if_necessary("sdkit", expected_sdkit_version_str)
install_pkg_if_necessary("diffusers", expected_diffusers_version_str) install_pkg_if_necessary("diffusers", expected_diffusers_version_str)