diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 8941ef0e..0861febb 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -127,6 +127,16 @@ def update_modules(): if module_name in modules_to_log: print(f"{module_name}: {version(module_name)}") + # hotfix accelerate + accelerate_version = version("accelerate") + if accelerate_version is None: + install("accelerate", "0.23.0") + else: + accelerate_version = accelerate_version.split(".") + accelerate_version = tuple(map(int, accelerate_version)) + if accelerate_version < (0, 23): + install("accelerate", "0.23.0") + ### utilities