forked from extern/easydiffusion
Install PyTorch 2.0 by default, but allow existing PyTorch 1.13.1 installations to continue running; Unify and streamline the installation of dependencies
This commit is contained in:
@ -1,23 +1,111 @@
|
||||
"""
|
||||
This script checks if the given modules exist
|
||||
|
||||
E.g. python check_modules.py sdkit==1.0.3 sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan
|
||||
This script checks and installs the required modules.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pkgutil
|
||||
from importlib.metadata import version
|
||||
import os
|
||||
from importlib.metadata import version as pkg_version
|
||||
import platform
|
||||
|
||||
modules = sys.argv[1:]
|
||||
missing_modules = []
|
||||
for m in modules:
|
||||
m = m.split("==")
|
||||
module_name = m[0]
|
||||
module_version = m[1] if len(m) > 1 else None
|
||||
is_installed = pkgutil.find_loader(module_name) is not None
|
||||
if not is_installed:
|
||||
print("module", module_name, "not found")
|
||||
exit(1)
|
||||
elif module_version and version(module_name) != module_version:
|
||||
print("module version is different! expected: ", module_version, ", actual: ", version(module_name))
|
||||
exit(1)
|
||||
os_name = platform.system()
|
||||
|
||||
modules_to_check = {
|
||||
"torch": ("1.13.1", "2.0.0"),
|
||||
"torchvision": ("0.14.1", "0.15.1"),
|
||||
"sdkit": "1.0.72",
|
||||
"stable-diffusion-sdkit": "2.1.4",
|
||||
"rich": "12.6.0",
|
||||
"uvicorn": "0.19.0",
|
||||
"fastapi": "0.85.1",
|
||||
}
|
||||
|
||||
|
||||
def version(module_name: str) -> str:
|
||||
try:
|
||||
return pkg_version(module_name)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def install(module_name: str, module_version: str):
|
||||
index_url = None
|
||||
if module_name in ("torch", "torchvision"):
|
||||
module_version, index_url = apply_torch_install_overrides(module_version)
|
||||
|
||||
install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}"
|
||||
if index_url:
|
||||
install_cmd += f" --index-url {index_url}"
|
||||
if module_name == "sdkit":
|
||||
install_cmd += " -q"
|
||||
|
||||
print(">", install_cmd)
|
||||
os.system(install_cmd)
|
||||
|
||||
|
||||
def init():
|
||||
for module_name, allowed_versions in modules_to_check.items():
|
||||
if os.path.exists(f"../src/{module_name}"):
|
||||
print(f"Skipping {module_name} update, since it's in developer/editable mode")
|
||||
continue
|
||||
|
||||
allowed_versions, latest_version = get_allowed_versions(module_name, allowed_versions)
|
||||
if version(module_name) not in allowed_versions:
|
||||
try:
|
||||
install(module_name, latest_version)
|
||||
except:
|
||||
fail(module_name)
|
||||
|
||||
print(f"{module_name}: {version(module_name)}")
|
||||
|
||||
|
||||
### utilities
|
||||
|
||||
|
||||
def get_allowed_versions(module_name: str, allowed_versions: tuple[str]):
|
||||
allowed_versions = (allowed_versions,) if isinstance(allowed_versions, str) else allowed_versions
|
||||
latest_version = allowed_versions[-1]
|
||||
|
||||
if module_name in ("torch", "torchvision"):
|
||||
allowed_versions = include_cuda_versions(allowed_versions)
|
||||
|
||||
return allowed_versions, latest_version
|
||||
|
||||
|
||||
def apply_torch_install_overrides(module_version: str):
|
||||
index_url = None
|
||||
if os_name == "Windows":
|
||||
module_version += "+cu117"
|
||||
index_url = "https://download.pytorch.org/whl/cu117"
|
||||
elif os_name == "Linux":
|
||||
with open("/proc/bus/pci/devices", "r") as f:
|
||||
device_info = f.read()
|
||||
if "amdgpu" in device_info:
|
||||
index_url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||
|
||||
return module_version, index_url
|
||||
|
||||
|
||||
def include_cuda_versions(module_versions: tuple) -> tuple:
|
||||
"Adds CUDA-specific versions to the list of allowed version numbers"
|
||||
|
||||
allowed_versions = tuple(module_versions)
|
||||
allowed_versions += tuple(f"{v}+cu116" for v in module_versions)
|
||||
allowed_versions += tuple(f"{v}+cu117" for v in module_versions)
|
||||
|
||||
return allowed_versions
|
||||
|
||||
|
||||
def fail(module_name):
|
||||
print(
|
||||
f"""Error installing {module_name}. Sorry about that, please try to:
|
||||
1. Run this installer again.
|
||||
2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting
|
||||
3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB
|
||||
4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues
|
||||
Thanks!"""
|
||||
)
|
||||
exit(1)
|
||||
|
||||
|
||||
### start
|
||||
|
||||
init()
|
||||
|
Reference in New Issue
Block a user