forked from extern/easydiffusion
Restrict AMD cards on Linux to torch 1.13.1 and ROCm 5.2. Avoids black images on some AMD cards. Temp hack until AMD works properly on torch 2.0
This commit is contained in:
parent
1967299417
commit
eb16296873
@ -42,6 +42,12 @@ def install(module_name: str, module_version: str):
|
|||||||
if module_name in ("torch", "torchvision"):
|
if module_name in ("torch", "torchvision"):
|
||||||
module_version, index_url = apply_torch_install_overrides(module_version)
|
module_version, index_url = apply_torch_install_overrides(module_version)
|
||||||
|
|
||||||
|
if is_amd_on_linux(): # hack until AMD works properly on torch 2.0 (avoids black images on some cards)
|
||||||
|
if module_name == "torch":
|
||||||
|
module_version = "1.13.1+rocm5.2"
|
||||||
|
elif module_name == "torchvision":
|
||||||
|
module_version = "0.14.1+rocm5.2"
|
||||||
|
|
||||||
install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}"
|
install_cmd = f"python -m pip install --upgrade {module_name}=={module_version}"
|
||||||
if index_url:
|
if index_url:
|
||||||
install_cmd += f" --index-url {index_url}"
|
install_cmd += f" --index-url {index_url}"
|
||||||
@ -96,7 +102,7 @@ def apply_torch_install_overrides(module_version: str):
|
|||||||
module_version += "+cu117"
|
module_version += "+cu117"
|
||||||
index_url = "https://download.pytorch.org/whl/cu117"
|
index_url = "https://download.pytorch.org/whl/cu117"
|
||||||
elif is_amd_on_linux():
|
elif is_amd_on_linux():
|
||||||
index_url = "https://download.pytorch.org/whl/rocm5.4.2"
|
index_url = "https://download.pytorch.org/whl/rocm5.2"
|
||||||
|
|
||||||
return module_version, index_url
|
return module_version, index_url
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user