mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-31 07:05:45 +02:00
Remove hardcoded references to torch.cuda; Use torchruntime and sdkit's device utilities instead
This commit is contained in:
parent
964aef6bc3
commit
9e21d681a0
@ -74,6 +74,7 @@ modules_to_check = {
|
||||
"onnxruntime": "1.19.2",
|
||||
"huggingface-hub": "0.21.4",
|
||||
"wandb": "0.13.7",
|
||||
"torchruntime": "1.8.0",
|
||||
}
|
||||
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
|
||||
|
||||
@ -169,10 +170,10 @@ def update_modules():
|
||||
# if sdkit is 2.0.15.x (or lower), then diffusers should be restricted to 0.21.4 (see below for the reason)
|
||||
# otherwise use the current sdkit version (with the corresponding diffusers version)
|
||||
|
||||
expected_sdkit_version_str = "2.0.22.3"
|
||||
expected_sdkit_version_str = "2.0.22.4"
|
||||
expected_diffusers_version_str = "0.28.2"
|
||||
|
||||
legacy_sdkit_version_str = "2.0.15.12"
|
||||
legacy_sdkit_version_str = "2.0.15.13"
|
||||
legacy_diffusers_version_str = "0.21.4"
|
||||
|
||||
sdkit_version_str = version("sdkit")
|
||||
|
@ -54,8 +54,7 @@ OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
APP_CONFIG_DEFAULTS = {
|
||||
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
||||
"render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
|
||||
"render_devices": "auto",
|
||||
"update_branch": "main",
|
||||
"ui": {
|
||||
"open_browser_on_start": True,
|
||||
|
@ -6,7 +6,14 @@ import traceback
|
||||
import torch
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit.utils import has_half_precision_bug
|
||||
from torchruntime.utils import (
|
||||
get_installed_torch_platform,
|
||||
get_device,
|
||||
get_device_count,
|
||||
get_device_name,
|
||||
SUPPORTED_BACKENDS,
|
||||
)
|
||||
from sdkit.utils import mem_get_info, is_cpu_device, has_half_precision_bug
|
||||
|
||||
"""
|
||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
||||
@ -24,33 +31,15 @@ mem_free_threshold = 0
|
||||
|
||||
def get_device_delta(render_devices, active_devices):
|
||||
"""
|
||||
render_devices: 'cpu', or 'auto', or 'mps' or ['cuda:N'...]
|
||||
active_devices: ['cpu', 'mps', 'cuda:N'...]
|
||||
render_devices: 'auto' or backends listed in `torchruntime.utils.SUPPORTED_BACKENDS`
|
||||
active_devices: [backends listed in `torchruntime.utils.SUPPORTED_BACKENDS`]
|
||||
"""
|
||||
|
||||
if render_devices in ("cpu", "auto", "mps"):
|
||||
render_devices = [render_devices]
|
||||
elif render_devices is not None:
|
||||
if isinstance(render_devices, str):
|
||||
render_devices = [render_devices]
|
||||
if isinstance(render_devices, list) and len(render_devices) > 0:
|
||||
render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
|
||||
if len(render_devices) == 0:
|
||||
raise Exception(
|
||||
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
|
||||
)
|
||||
render_devices = render_devices or "auto"
|
||||
render_devices = [render_devices] if isinstance(render_devices, str) else render_devices
|
||||
|
||||
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
|
||||
if len(render_devices) == 0:
|
||||
raise Exception(
|
||||
"Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
||||
)
|
||||
else:
|
||||
render_devices = ["auto"]
|
||||
# check for backend support
|
||||
validate_render_devices(render_devices)
|
||||
|
||||
if "auto" in render_devices:
|
||||
render_devices = auto_pick_devices(active_devices)
|
||||
@ -66,47 +55,39 @@ def get_device_delta(render_devices, active_devices):
|
||||
return devices_to_start, devices_to_stop
|
||||
|
||||
|
||||
def is_mps_available():
|
||||
return (
|
||||
platform.system() == "Darwin"
|
||||
and hasattr(torch.backends, "mps")
|
||||
and torch.backends.mps.is_available()
|
||||
and torch.backends.mps.is_built()
|
||||
)
|
||||
def validate_render_devices(render_devices):
|
||||
supported_backends = ("auto",) + SUPPORTED_BACKENDS
|
||||
unsupported_render_devices = [d for d in render_devices if not d.lower().startswith(supported_backends)]
|
||||
|
||||
|
||||
def is_cuda_available():
|
||||
return torch.cuda.is_available()
|
||||
if unsupported_render_devices:
|
||||
raise ValueError(
|
||||
f"Invalid render devices in config: {unsupported_render_devices}. Valid render devices: {supported_backends}"
|
||||
)
|
||||
|
||||
|
||||
def auto_pick_devices(currently_active_devices):
|
||||
global mem_free_threshold
|
||||
|
||||
if is_mps_available():
|
||||
return ["mps"]
|
||||
torch_platform_name = get_installed_torch_platform()[0]
|
||||
|
||||
if not is_cuda_available():
|
||||
return ["cpu"]
|
||||
|
||||
device_count = torch.cuda.device_count()
|
||||
if device_count == 1:
|
||||
return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]
|
||||
if is_cpu_device(torch_platform_name):
|
||||
return [torch_platform_name]
|
||||
|
||||
device_count = get_device_count()
|
||||
log.debug("Autoselecting GPU. Using most free memory.")
|
||||
devices = []
|
||||
for device in range(device_count):
|
||||
device = f"cuda:{device}"
|
||||
if not is_device_compatible(device):
|
||||
continue
|
||||
for device_id in range(device_count):
|
||||
device_id = f"{torch_platform_name}:{device_id}" if device_count > 1 else torch_platform_name
|
||||
device = get_device(device_id)
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_free, mem_total = mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
device_name = get_device_name(device)
|
||||
log.debug(
|
||||
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
||||
f"{device_id} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
||||
)
|
||||
devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})
|
||||
devices.append({"device": device_id, "device_name": device_name, "mem_free": mem_free})
|
||||
|
||||
devices.sort(key=lambda x: x["mem_free"], reverse=True)
|
||||
max_mem_free = devices[0]["mem_free"]
|
||||
@ -119,56 +100,45 @@ def auto_pick_devices(currently_active_devices):
|
||||
# always be very low (since their VRAM contains the model).
|
||||
# These already-running devices probably aren't terrible, since they were picked in the past.
|
||||
# Worst case, the user can restart the program and that'll get rid of them.
|
||||
devices = list(
|
||||
filter(
|
||||
(lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices),
|
||||
devices,
|
||||
)
|
||||
)
|
||||
devices = list(map(lambda x: x["device"], devices))
|
||||
devices = [
|
||||
x["device"] for x in devices if x["mem_free"] >= mem_free_threshold or x["device"] in currently_active_devices
|
||||
]
|
||||
return devices
|
||||
|
||||
|
||||
def device_init(context, device):
|
||||
"""
|
||||
This function assumes the 'device' has already been verified to be compatible.
|
||||
`get_device_delta()` has already filtered out incompatible devices.
|
||||
"""
|
||||
def device_init(context, device_id):
|
||||
context.device = device_id
|
||||
|
||||
validate_device_id(device, log_prefix="device_init")
|
||||
|
||||
if "cuda" not in device:
|
||||
context.device = device
|
||||
if is_cpu_device(context.torch_device):
|
||||
context.device_name = get_processor_name()
|
||||
context.half_precision = False
|
||||
log.debug(f"Render device available as {context.device_name}")
|
||||
return
|
||||
else:
|
||||
context.device_name = get_device_name(context.torch_device)
|
||||
|
||||
context.device_name = torch.cuda.get_device_name(device)
|
||||
context.device = device
|
||||
# Some graphics cards have bugs in their firmware that prevent image generation at half precision
|
||||
if needs_to_force_full_precision(context.device_name):
|
||||
log.warn(f"forcing full precision on this GPU, to avoid corrupted images. GPU: {context.device_name}")
|
||||
context.half_precision = False
|
||||
|
||||
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
||||
if needs_to_force_full_precision(context):
|
||||
log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
|
||||
# Apply force_full_precision now before models are loaded.
|
||||
context.half_precision = False
|
||||
|
||||
log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
|
||||
torch.cuda.device(device)
|
||||
log.info(f'Setting {device_id} as active, with precision: {"half" if context.half_precision else "full"}')
|
||||
|
||||
|
||||
def needs_to_force_full_precision(context):
|
||||
def needs_to_force_full_precision(device_name):
|
||||
if "FORCE_FULL_PRECISION" in os.environ:
|
||||
return True
|
||||
|
||||
device_name = context.device_name.lower()
|
||||
return has_half_precision_bug(device_name)
|
||||
return has_half_precision_bug(device_name.lower())
|
||||
|
||||
|
||||
def get_max_vram_usage_level(device):
|
||||
if "cuda" in device:
|
||||
_, mem_total = torch.cuda.mem_get_info(device)
|
||||
else:
|
||||
"Expects a torch.device as the argument"
|
||||
|
||||
if is_cpu_device(device):
|
||||
return "high"
|
||||
|
||||
_, mem_total = mem_get_info(device)
|
||||
|
||||
if mem_total < 0.001: # probably a torch platform without a mem_get_info() implementation
|
||||
return "high"
|
||||
|
||||
mem_total /= float(10**9)
|
||||
@ -180,51 +150,6 @@ def get_max_vram_usage_level(device):
|
||||
return "high"
|
||||
|
||||
|
||||
def validate_device_id(device, log_prefix=""):
|
||||
def is_valid():
|
||||
if not isinstance(device, str):
|
||||
return False
|
||||
if device == "cpu" or device == "mps":
|
||||
return True
|
||||
if not device.startswith("cuda:") or not device[5:].isnumeric():
|
||||
return False
|
||||
return True
|
||||
|
||||
if not is_valid():
|
||||
raise EnvironmentError(
|
||||
f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
|
||||
)
|
||||
|
||||
|
||||
def is_device_compatible(device):
|
||||
"""
|
||||
Returns True/False, and prints any compatibility errors
|
||||
"""
|
||||
# static variable "history".
|
||||
is_device_compatible.history = getattr(is_device_compatible, "history", {})
|
||||
try:
|
||||
validate_device_id(device, log_prefix="is_device_compatible")
|
||||
except:
|
||||
log.error(str(e))
|
||||
return False
|
||||
|
||||
if device in ("cpu", "mps"):
|
||||
return True
|
||||
# Memory check
|
||||
try:
|
||||
_, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
if mem_total < 1.9:
|
||||
if is_device_compatible.history.get(device) == None:
|
||||
log.warn(f"GPU {device} with less than 2 GB of VRAM is not compatible with Stable Diffusion")
|
||||
is_device_compatible.history[device] = 1
|
||||
return False
|
||||
except RuntimeError as e:
|
||||
log.error(str(e))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_processor_name():
|
||||
try:
|
||||
import subprocess
|
||||
|
@ -208,11 +208,13 @@ def set_app_config_internal(req: SetAppConfigRequest):
|
||||
|
||||
|
||||
def update_render_devices_in_config(config, render_devices):
|
||||
if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")
|
||||
from easydiffusion.device_manager import validate_render_devices
|
||||
|
||||
if render_devices.startswith("cuda:"):
|
||||
try:
|
||||
render_devices = render_devices.split(",")
|
||||
validate_render_devices(render_devices)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
config["render_devices"] = render_devices
|
||||
|
||||
|
@ -21,6 +21,9 @@ from easydiffusion import device_manager
|
||||
from easydiffusion.tasks import Task
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from torchruntime.utils import get_device_count, get_device, get_device_name, get_installed_torch_platform
|
||||
from sdkit.utils import is_cpu_device, mem_get_info
|
||||
|
||||
THREAD_NAME_PREFIX = ""
|
||||
ERR_LOCK_FAILED = " failed to acquire lock within timeout."
|
||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
@ -339,34 +342,33 @@ def get_devices():
|
||||
"active": {},
|
||||
}
|
||||
|
||||
def get_device_info(device):
|
||||
if device in ("cpu", "mps"):
|
||||
def get_device_info(device_id):
|
||||
if is_cpu_device(device_id):
|
||||
return {"name": device_manager.get_processor_name()}
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
device = get_device(device_id)
|
||||
|
||||
mem_free, mem_total = mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
|
||||
return {
|
||||
"name": torch.cuda.get_device_name(device),
|
||||
"name": get_device_name(device),
|
||||
"mem_free": mem_free,
|
||||
"mem_total": mem_total,
|
||||
"max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
|
||||
}
|
||||
|
||||
# list the compatible devices
|
||||
cuda_count = torch.cuda.device_count()
|
||||
for device in range(cuda_count):
|
||||
device = f"cuda:{device}"
|
||||
if not device_manager.is_device_compatible(device):
|
||||
continue
|
||||
torch_platform_name = get_installed_torch_platform()[0]
|
||||
device_count = get_device_count()
|
||||
for device_id in range(device_count):
|
||||
device_id = f"{torch_platform_name}:{device_id}" if device_count > 1 else torch_platform_name
|
||||
|
||||
devices["all"].update({device: get_device_info(device)})
|
||||
devices["all"].update({device_id: get_device_info(device_id)})
|
||||
|
||||
if device_manager.is_mps_available():
|
||||
devices["all"].update({"mps": get_device_info("mps")})
|
||||
|
||||
devices["all"].update({"cpu": get_device_info("cpu")})
|
||||
if torch_platform_name != "cpu":
|
||||
devices["all"].update({"cpu": get_device_info("cpu")})
|
||||
|
||||
# list the activated devices
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
@ -378,8 +380,8 @@ def get_devices():
|
||||
weak_data = weak_thread_data.get(rthread)
|
||||
if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
|
||||
continue
|
||||
device = weak_data["device"]
|
||||
devices["active"].update({device: get_device_info(device)})
|
||||
device_id = weak_data["device"]
|
||||
devices["active"].update({device_id: get_device_info(device_id)})
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
@ -437,12 +439,6 @@ def start_render_thread(device):
|
||||
|
||||
|
||||
def stop_render_thread(device):
|
||||
try:
|
||||
device_manager.validate_device_id(device, log_prefix="stop_render_thread")
|
||||
except:
|
||||
log.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
|
||||
log.info(f"Stopping Rendering Thread on device: {device}")
|
||||
|
@ -182,23 +182,6 @@ function loadSettings() {
|
||||
}
|
||||
})
|
||||
CURRENTLY_LOADING_SETTINGS = false
|
||||
} else if (localStorage.length < 2) {
|
||||
// localStorage is too short for OldSettings
|
||||
// So this is likely the first time Easy Diffusion is running.
|
||||
// Initialize vram_usage_level based on the available VRAM
|
||||
function initGPUProfile(event) {
|
||||
if (
|
||||
"detail" in event &&
|
||||
"active" in event.detail &&
|
||||
"cuda:0" in event.detail.active &&
|
||||
event.detail.active["cuda:0"].mem_total < 4.5
|
||||
) {
|
||||
vramUsageLevelField.value = "low"
|
||||
vramUsageLevelField.dispatchEvent(new Event("change"))
|
||||
}
|
||||
document.removeEventListener("system_info_update", initGPUProfile)
|
||||
}
|
||||
document.addEventListener("system_info_update", initGPUProfile)
|
||||
} else {
|
||||
CURRENTLY_LOADING_SETTINGS = true
|
||||
tryLoadOldSettings()
|
||||
|
@ -658,7 +658,7 @@ function setDeviceInfo(devices) {
|
||||
|
||||
function ID_TO_TEXT(d) {
|
||||
let info = devices.all[d]
|
||||
if ("mem_free" in info && "mem_total" in info) {
|
||||
if ("mem_free" in info && "mem_total" in info && info["mem_total"] > 0) {
|
||||
return `${info.name} <small>(${d}) (${info.mem_free.toFixed(1)}Gb free / ${info.mem_total.toFixed(
|
||||
1
|
||||
)} Gb total)</small>`
|
||||
|
Loading…
x
Reference in New Issue
Block a user