mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 00:03:20 +01:00
Automatically detect whether NAVI1/2 or NAVI3 ROCm versions are needed
This commit is contained in:
parent
e9f54c8bae
commit
53a79c1a81
@ -14,6 +14,7 @@ import platform
|
||||
import traceback
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
os_name = platform.system()
|
||||
|
||||
@ -48,10 +49,21 @@ def install(module_name: str, module_version: str):
|
||||
module_version, index_url = apply_torch_install_overrides(module_version)
|
||||
|
||||
if is_amd_on_linux(): # hack until AMD works properly on torch 2.0 (avoids black images on some cards)
|
||||
amd_gpus = setup_amd_environment()
|
||||
if module_name == "torch":
|
||||
module_version = "1.13.1+rocm5.2"
|
||||
if "Navi 3" in amd_gpus:
|
||||
# No AMD 7x00 support in rocm 5.2, needs nightly 5.5. build
|
||||
module_version = "2.1.0.dev-20230614+rocm5.5"
|
||||
index_url = "https: //download.pytorch.org/whl/nightly/rocm5.5"
|
||||
else:
|
||||
module_version = "1.13.1+rocm5.2"
|
||||
elif module_name == "torchvision":
|
||||
module_version = "0.14.1+rocm5.2"
|
||||
if "Navi 3" in amd_gpus:
|
||||
# No AMD 7x00 support in rocm 5.2, needs nightly 5.5. build
|
||||
module_version = "0.16.0.dev-20230614+rocm5.5"
|
||||
index_url = "https: //download.pytorch.org/whl/nightly/rocm5.5"
|
||||
else:
|
||||
module_version = "0.14.1+rocm5.2"
|
||||
elif os_name == "Darwin":
|
||||
if module_name == "torch":
|
||||
module_version = "1.13.1"
|
||||
@ -167,7 +179,7 @@ Thanks!"""
|
||||
def get_config():
|
||||
# The config file is in the same directory as this script
|
||||
config_directory = os.path.dirname(__file__)
|
||||
config_yaml = os.path.join(config_directory, "..", "config.yaml")
|
||||
config_yaml = os.path.join(config_directory, "config.yaml")
|
||||
config_json = os.path.join(config_directory, "config.json")
|
||||
|
||||
config = None
|
||||
@ -206,28 +218,35 @@ def setup_amd_environment():
|
||||
gpus = list(filter(lambda x: ("amdgpu" in x), open("/proc/bus/pci/devices", "r").readlines()))
|
||||
gpus = [ AMD_PCI_IDs[x.split("\t")[1].upper()] for x in gpus ]
|
||||
i=0
|
||||
supported_gpus=[]
|
||||
for gpu in gpus:
|
||||
print(f"Found AMD GPU {gpu}.")
|
||||
if gpu.startswith("Navi 1"):
|
||||
print("--- Applying Navi 1 settings")
|
||||
os.environ["HSA_OVERRIDE_GFX_VERSION"]="10.3.0"
|
||||
os.environ["FORCE_FULL_PRECISION"]="yes"
|
||||
os.environ["HIP_VISIBLE_DEVICES"]=str(i)
|
||||
supported_gpus.append("Navi 1")
|
||||
elif gpu.startswith("Navi 2"):
|
||||
print("--- Applying Navi 2 settings")
|
||||
os.environ["HSA_OVERRIDE_GFX_VERSION"]="10.3.0"
|
||||
os.environ["HIP_VISIBLE_DEVICES"]=str(i)
|
||||
supported_gpus.append("Navi 2")
|
||||
elif gpu.startswith("Navi 3"):
|
||||
print("\n ---- TODO ---- \n")
|
||||
print("--- Applying Navi 3 settings")
|
||||
os.environ["HSA_OVERRIDE_GFX_VERSION"]="11.0.0"
|
||||
os.environ["HIP_VISIBLE_DEVICES"]=str(i)
|
||||
supported_gpus.append("Navi 3")
|
||||
else:
|
||||
print("\nThis GPU is probably not supported by ROCm\n")
|
||||
print("--- This GPU is probably not supported by ROCm\n")
|
||||
i+=1
|
||||
return supported_gpus
|
||||
|
||||
|
||||
def launch_uvicorn():
|
||||
config = get_config()
|
||||
|
||||
print(config)
|
||||
pprint(config)
|
||||
|
||||
with open("scripts/install_status.txt","a") as f:
|
||||
f.write("sd_weights_downloaded\n")
|
||||
@ -246,13 +265,16 @@ def launch_uvicorn():
|
||||
|
||||
bind_ip = "127.0.0.1"
|
||||
if "net" in config:
|
||||
print("Checking network settings")
|
||||
if "listen_port" in config["net"]:
|
||||
listen_port = config["net"]["listen_port"]
|
||||
if "listen_to_network" in config["net"] and config["net"]["listen_to_network"] == "True":
|
||||
print("Set listen port to ", listen_port)
|
||||
if "listen_to_network" in config["net"] and config["net"]["listen_to_network"] == True:
|
||||
if "bind_ip" in config["net"]:
|
||||
bind_ip = config["net"]["bind_ip"]
|
||||
else:
|
||||
bind_ip = "0.0.0.0"
|
||||
print("Set bind_ip to ", bind_ip)
|
||||
|
||||
os.chdir("stable-diffusion")
|
||||
|
||||
@ -269,6 +291,8 @@ def launch_uvicorn():
|
||||
|
||||
### Start
|
||||
|
||||
# This list would probably be a good candidate for an import, but since PYTHONPATH and other settings
|
||||
# have not been initialized yet, I keep the list here for the moment -- JeLuF
|
||||
AMD_PCI_IDs = {
|
||||
"1002AC0C": "Theater 506A USB",
|
||||
"1002AC0D": "Theater 506A USB",
|
||||
|
Loading…
Reference in New Issue
Block a user