diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 3b7ccd58..a2122757 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -69,8 +69,9 @@ def update_modules(): else: torch_version_str = version("torch") torch_version = version_str_to_tuple(torch_version_str) + is_cpu_torch = "+" not in torch_version_str print(f"Current torch version: {torch_version} ({torch_version_str})") - if torch_version < (2, 7): + if torch_version < (2, 7) or is_cpu_torch: gpu_infos = get_gpus() device_names = set(gpu.device_name for gpu in gpu_infos) if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names):