Automatically detect whether NAVI1/2 or NAVI3 ROCm versions are needed

This commit is contained in:
JeLuF 2023-08-11 22:36:20 +02:00
parent e9f54c8bae
commit 53a79c1a81

View File

@ -14,6 +14,7 @@ import platform
import traceback
import shutil
from pathlib import Path
from pprint import pprint
os_name = platform.system()
@ -48,10 +49,21 @@ def install(module_name: str, module_version: str):
module_version, index_url = apply_torch_install_overrides(module_version)
if is_amd_on_linux(): # hack until AMD works properly on torch 2.0 (avoids black images on some cards)
amd_gpus = setup_amd_environment()
if module_name == "torch":
module_version = "1.13.1+rocm5.2"
if "Navi 3" in amd_gpus:
# ROCm 5.2 has no support for AMD 7x00 (Navi 3) cards; they need the nightly ROCm 5.5 build
module_version = "2.1.0.dev-20230614+rocm5.5"
index_url = "https: //download.pytorch.org/whl/nightly/rocm5.5"
else:
module_version = "1.13.1+rocm5.2"
elif module_name == "torchvision":
module_version = "0.14.1+rocm5.2"
if "Navi 3" in amd_gpus:
# ROCm 5.2 has no support for AMD 7x00 (Navi 3) cards; they need the nightly ROCm 5.5 build
module_version = "0.16.0.dev-20230614+rocm5.5"
index_url = "https: //download.pytorch.org/whl/nightly/rocm5.5"
else:
module_version = "0.14.1+rocm5.2"
elif os_name == "Darwin":
if module_name == "torch":
module_version = "1.13.1"
@ -167,7 +179,7 @@ Thanks!"""
def get_config():
# The config file is in the same directory as this script
config_directory = os.path.dirname(__file__)
config_yaml = os.path.join(config_directory, "..", "config.yaml")
config_yaml = os.path.join(config_directory, "config.yaml")
config_json = os.path.join(config_directory, "config.json")
config = None
@ -206,28 +218,35 @@ def setup_amd_environment():
gpus = list(filter(lambda x: ("amdgpu" in x), open("/proc/bus/pci/devices", "r").readlines()))
gpus = [ AMD_PCI_IDs[x.split("\t")[1].upper()] for x in gpus ]
i=0
supported_gpus=[]
for gpu in gpus:
print(f"Found AMD GPU {gpu}.")
if gpu.startswith("Navi 1"):
print("--- Applying Navi 1 settings")
os.environ["HSA_OVERRIDE_GFX_VERSION"]="10.3.0"
os.environ["FORCE_FULL_PRECISION"]="yes"
os.environ["HIP_VISIBLE_DEVICES"]=str(i)
supported_gpus.append("Navi 1")
elif gpu.startswith("Navi 2"):
print("--- Applying Navi 2 settings")
os.environ["HSA_OVERRIDE_GFX_VERSION"]="10.3.0"
os.environ["HIP_VISIBLE_DEVICES"]=str(i)
supported_gpus.append("Navi 2")
elif gpu.startswith("Navi 3"):
print("\n ---- TODO ---- \n")
print("--- Applying Navi 3 settings")
os.environ["HSA_OVERRIDE_GFX_VERSION"]="11.0.0"
os.environ["HIP_VISIBLE_DEVICES"]=str(i)
supported_gpus.append("Navi 3")
else:
print("\nThis GPU is probably not supported by ROCm\n")
print("--- This GPU is probably not supported by ROCm\n")
i+=1
return supported_gpus
def launch_uvicorn():
config = get_config()
print(config)
pprint(config)
with open("scripts/install_status.txt","a") as f:
f.write("sd_weights_downloaded\n")
@ -246,13 +265,16 @@ def launch_uvicorn():
bind_ip = "127.0.0.1"
if "net" in config:
print("Checking network settings")
if "listen_port" in config["net"]:
listen_port = config["net"]["listen_port"]
if "listen_to_network" in config["net"] and config["net"]["listen_to_network"] == "True":
print("Set listen port to ", listen_port)
if "listen_to_network" in config["net"] and config["net"]["listen_to_network"] == True:
if "bind_ip" in config["net"]:
bind_ip = config["net"]["bind_ip"]
else:
bind_ip = "0.0.0.0"
print("Set bind_ip to ", bind_ip)
os.chdir("stable-diffusion")
@ -269,6 +291,8 @@ def launch_uvicorn():
### Start
# This list would probably be a good candidate for an import, but since PYTHONPATH and other settings
# have not been initialized yet, I keep the list here for the moment -- JeLuF
AMD_PCI_IDs = {
"1002AC0C": "Theater 506A USB",
"1002AC0D": "Theater 506A USB",