diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 89297d69..7b3eb458 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -59,7 +59,15 @@ def init(): continue allowed_versions, latest_version = get_allowed_versions(module_name, allowed_versions) - if version(module_name) not in allowed_versions: + + requires_install = False + if module_name in ("torch", "torchvision"): + if version(module_name) is None: # allow any torch version + requires_install = True + elif version(module_name) not in allowed_versions: + requires_install = True + + if requires_install: try: install(module_name, latest_version) except: