mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-12-27 01:19:05 +01:00
159 lines
5.0 KiB
Python
159 lines
5.0 KiB
Python
"""
|
|
This script checks and installs the required modules.
|
|
|
|
This script runs inside the legacy "stable-diffusion" folder.
|
|
|
|
TODO - Maybe replace the bulk of this script with a call to `pip install -r requirements.txt`, with
|
|
a custom index URL depending on the platform.
|
|
|
|
"""
|
|
|
|
import os
import platform
import traceback
from importlib.metadata import version as pkg_version
from typing import Optional
|
|
|
|
os_name = platform.system()

# Package name -> accepted version(s).
# A bare string means exactly that version. For a tuple, every entry is accepted and
# init() installs the last one (via get_allowed_versions). torch/torchvision are special:
# see init() and install() for their platform-specific handling.
modules_to_check = dict(
    [
        ("torch", ("1.11.0", "1.13.1", "2.0.0")),
        ("torchvision", ("0.12.0", "0.14.1", "0.15.1")),
        ("sdkit", "1.0.97"),
        ("stable-diffusion-sdkit", "2.1.4"),
        ("rich", "12.6.0"),
        ("uvicorn", "0.19.0"),
        ("fastapi", "0.85.1"),
        # ("xformers", "0.0.16"),
    ]
)
|
|
|
|
|
|
def version(module_name: str) -> Optional[str]:
    """Return the installed version of distribution ``module_name``, or None.

    None means the package is not installed (or its metadata could not be read);
    init() treats None as "needs to be installed".
    """
    try:
        return pkg_version(module_name)
    except Exception:  # best-effort lookup, but let KeyboardInterrupt/SystemExit propagate
        return None
|
|
|
|
|
|
def install(module_name: str, module_version: str):
    """Run ``pip install --upgrade module_name==module_version``, applying platform overrides.

    xformers is skipped entirely on macOS and on AMD-only Linux. For torch/torchvision:
    - Windows: the ``+cu117`` build from the PyTorch cu117 index
      (see apply_torch_install_overrides).
    - AMD on Linux: pinned to torch 1.13.1 / torchvision 0.14.1 ``+rocm5.2`` builds.
    - macOS: pinned to torch 1.13.1 / torchvision 0.14.1.

    Raises:
        RuntimeError: if the pip command exits with a non-zero status, so that the
            caller's error handling (init() -> fail()) is actually triggered.
    """
    if module_name == "xformers" and (os_name == "Darwin" or is_amd_on_linux()):
        return

    index_url = None
    if module_name in ("torch", "torchvision"):
        module_version, index_url = apply_torch_install_overrides(module_version)

    if is_amd_on_linux():  # hack until AMD works properly on torch 2.0 (avoids black images on some cards)
        if module_name == "torch":
            module_version = "1.13.1+rocm5.2"
        elif module_name == "torchvision":
            module_version = "0.14.1+rocm5.2"
    elif os_name == "Darwin":
        if module_name == "torch":
            module_version = "1.13.1"
        elif module_name == "torchvision":
            module_version = "0.14.1"

    install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}"
    if index_url:
        install_cmd += f" --index-url {index_url}"
    if module_name == "sdkit" and version("sdkit") is not None:
        install_cmd += " -q"  # quieter output when only upgrading an existing sdkit

    print(">", install_cmd)
    exit_status = os.system(install_cmd)
    # os.system() signals failure only through its return value (it never raises),
    # so without this check a failed pip run would go unnoticed by init().
    if exit_status != 0:
        raise RuntimeError(f"'{install_cmd}' failed with exit status {exit_status}")
|
|
|
|
|
|
def init():
    """Check every entry of ``modules_to_check`` and (re)install it when needed.

    - Modules with a checkout at ``../src/<name>`` (developer/editable mode) are skipped.
    - torch/torchvision: any installed version is accepted, except on macOS, where torch 2.x
      or torchvision 0.15.x triggers a reinstall (install() pins macOS to 1.13.1 / 0.14.1).
    - Every other module must match one of its allowed versions exactly.

    On an install error, the traceback is printed and fail() exits the process.
    """
    for module_name, allowed_versions in modules_to_check.items():
        if os.path.exists(f"../src/{module_name}"):
            print(f"Skipping {module_name} update, since it's in developer/editable mode")
            continue

        allowed_versions, latest_version = get_allowed_versions(module_name, allowed_versions)

        installed_version = version(module_name)
        requires_install = False
        if module_name in ("torch", "torchvision"):
            if installed_version is None:  # allow any torch version
                requires_install = True
            elif os_name == "Darwin" and (  # force mac to downgrade from torch 2.0
                # version() returns None for a missing package; treat that as "no match"
                # instead of crashing on None.startswith(...).
                (version("torch") or "").startswith("2.")
                or (version("torchvision") or "").startswith("0.15.")
            ):
                requires_install = True
        elif installed_version not in allowed_versions:
            requires_install = True

        if requires_install:
            try:
                install(module_name, latest_version)
            except Exception:
                traceback.print_exc()
                fail(module_name)

        print(f"{module_name}: {version(module_name)}")
|
|
|
|
|
|
### utilities
|
|
|
|
|
|
def get_allowed_versions(module_name: str, allowed_versions: tuple):
    """Normalize a ``modules_to_check`` entry into ``(accepted_versions, version_to_install)``.

    A single version string is wrapped in a 1-tuple; the version to install is the last
    entry. For torch/torchvision the accepted versions are widened with the CUDA/ROCm
    build variants produced by include_cuda_versions().
    """
    if isinstance(allowed_versions, str):
        versions = (allowed_versions,)
    else:
        versions = allowed_versions
    install_target = versions[-1]

    is_torch_package = module_name in ("torch", "torchvision")
    accepted = include_cuda_versions(versions) if is_torch_package else versions
    return accepted, install_target
|
|
|
|
|
|
def apply_torch_install_overrides(module_version: str):
    """Return ``(module_version, index_url)`` adjusted for the current platform.

    - Windows: append the ``+cu117`` local version tag and use the PyTorch cu117 wheel index.
    - AMD on Linux: keep the version and use the PyTorch rocm5.2 wheel index.
    - Anything else: version unchanged, ``None`` index URL (pip's default index).
    """
    if os_name == "Windows":
        return module_version + "+cu117", "https://download.pytorch.org/whl/cu117"
    if is_amd_on_linux():
        return module_version, "https://download.pytorch.org/whl/rocm5.2"
    return module_version, None
|
|
|
|
|
|
def include_cuda_versions(module_versions: tuple) -> tuple:
    """Return the given versions followed by their CUDA/ROCm build variants.

    Variants are grouped by build tag in this order: +cu116, +cu117, +rocm5.2,
    +rocm5.4.2. Within each group they keep the order of ``module_versions``.
    """
    build_tags = ("+cu116", "+cu117", "+rocm5.2", "+rocm5.4.2")
    plain = tuple(module_versions)
    tagged = tuple(f"{v}{tag}" for tag in build_tags for v in module_versions)
    return plain + tagged
|
|
|
|
|
|
def is_amd_on_linux():
    """Return True when running on Linux with an AMD GPU and no NVIDIA GPU.

    Detection looks for the substrings ``amdgpu`` and ``nvidia`` in
    ``/proc/bus/pci/devices`` (presumably the bound kernel driver names; verify).
    If that file cannot be read, the answer is False, so detection stays best-effort.
    """
    if os_name != "Linux":
        return False

    try:
        with open("/proc/bus/pci/devices", "r") as f:
            device_info = f.read()
    except (OSError, ValueError):  # missing/unreadable file; ValueError covers UnicodeDecodeError
        return False

    return "amdgpu" in device_info and "nvidia" not in device_info
|
|
|
|
|
|
def fail(module_name):
    """Print troubleshooting steps for a failed install of ``module_name``, then exit with status 1."""
    import sys  # local import: sys is not among this file's top-level imports

    print(
        f"""Error installing {module_name}. Sorry about that, please try to:
1. Run this installer again.
2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting
3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB
4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues
Thanks!"""
    )
    # sys.exit() rather than the site-provided exit(), which is meant for the
    # interactive shell and does not exist when Python runs with -S.
    sys.exit(1)
|
|
|
|
|
|
### start
|
|
|
|
# Run the module check as soon as this file is loaded. There is no
# `if __name__ == "__main__"` guard, so importing this file also triggers it.
init()
|