diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 8ef43b09..6d39fcfc 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -47,6 +47,11 @@ def install(module_name: str, module_version: str): module_version = "1.13.1+rocm5.2" elif module_name == "torchvision": module_version = "0.14.1+rocm5.2" + elif os_name == "Darwin": + if module_name == "torch": + module_version = "1.13.1" + elif module_name == "torchvision": + module_version = "0.14.1" install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}" if index_url: @@ -70,6 +75,10 @@ def init(): if module_name in ("torch", "torchvision"): if version(module_name) is None: # allow any torch version requires_install = True + elif os_name == "Darwin" and ( # force mac to downgrade from torch 2.0 + version("torch").startswith("2.") or version("torchvision").startswith("0.15.") + ): + requires_install = True elif version(module_name) not in allowed_versions: requires_install = True