Remove hardcoded references to torch.cuda; Use torchruntime and sdkit's device utilities instead

cmdr2 2025-02-06 13:06:51 +05:30
parent 964aef6bc3
commit 9e21d681a0
7 changed files with 83 additions and 177 deletions
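Before the per-file diffs, a minimal sketch (mine, not part of the commit) of the replacement pattern being applied: the CUDA-only torch.cuda calls give way to the backend-agnostic helpers that device_manager.py imports below.

# Old, CUDA-only:
#   device_count = torch.cuda.device_count()
#   name = torch.cuda.get_device_name(device)
#   mem_free, mem_total = torch.cuda.mem_get_info(device)
# New, backend-agnostic:
from torchruntime.utils import get_installed_torch_platform, get_device, get_device_count, get_device_name
from sdkit.utils import is_cpu_device, mem_get_info

torch_platform_name = get_installed_torch_platform()[0]  # e.g. "cuda" or "cpu"
if not is_cpu_device(torch_platform_name):
    count = get_device_count()
    for i in range(count):
        device_id = f"{torch_platform_name}:{i}" if count > 1 else torch_platform_name
        device = get_device(device_id)
        mem_free, mem_total = mem_get_info(device)  # bytes, like torch.cuda.mem_get_info
        print(device_id, get_device_name(device), mem_free, mem_total)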

View File

@@ -74,6 +74,7 @@ modules_to_check = {
"onnxruntime": "1.19.2",
"huggingface-hub": "0.21.4",
"wandb": "0.13.7",
"torchruntime": "1.8.0",
}
modules_to_log = ["torch", "torchvision", "sdkit", "stable-diffusion-sdkit", "diffusers"]
@@ -169,10 +170,10 @@ def update_modules():
# if sdkit is 2.0.15.x (or lower), then diffusers should be restricted to 0.21.4 (see below for the reason)
# otherwise use the current sdkit version (with the corresponding diffusers version)
expected_sdkit_version_str = "2.0.22.3"
expected_sdkit_version_str = "2.0.22.4"
expected_diffusers_version_str = "0.28.2"
legacy_sdkit_version_str = "2.0.15.12"
legacy_sdkit_version_str = "2.0.15.13"
legacy_diffusers_version_str = "0.21.4"
sdkit_version_str = version("sdkit")

View File

@@ -54,8 +54,7 @@ OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
"render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
"render_devices": "auto",
"update_branch": "main",
"ui": {
"open_browser_on_start": True,

View File

@@ -6,7 +6,14 @@ import traceback
import torch
from easydiffusion.utils import log
from sdkit.utils import has_half_precision_bug
from torchruntime.utils import (
get_installed_torch_platform,
get_device,
get_device_count,
get_device_name,
SUPPORTED_BACKENDS,
)
from sdkit.utils import mem_get_info, is_cpu_device, has_half_precision_bug
"""
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
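(Illustration, mine: the needs_to_force_full_precision() check further down only tests for the variable's presence, so any value works.)

import os
os.environ["FORCE_FULL_PRECISION"] = "yes"  # equivalent to `export FORCE_FULL_PRECISION=yes` in config.sh, or `set FORCE_FULL_PRECISION=yes` in config.bat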
@@ -24,33 +31,15 @@ mem_free_threshold = 0
def get_device_delta(render_devices, active_devices):
"""
render_devices: 'cpu', or 'auto', or 'mps' or ['cuda:N'...]
active_devices: ['cpu', 'mps', 'cuda:N'...]
render_devices: 'auto' or backends listed in `torchruntime.utils.SUPPORTED_BACKENDS`
active_devices: [backends listed in `torchruntime.utils.SUPPORTED_BACKENDS`]
"""
if render_devices in ("cpu", "auto", "mps"):
render_devices = [render_devices]
elif render_devices is not None:
if isinstance(render_devices, str):
render_devices = [render_devices]
if isinstance(render_devices, list) and len(render_devices) > 0:
render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
if len(render_devices) == 0:
raise Exception(
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
)
render_devices = render_devices or "auto"
render_devices = [render_devices] if isinstance(render_devices, str) else render_devices
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
if len(render_devices) == 0:
raise Exception(
"Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
)
else:
raise Exception(
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
)
else:
render_devices = ["auto"]
# check for backend support
validate_render_devices(render_devices)
if "auto" in render_devices:
render_devices = auto_pick_devices(active_devices)
@@ -66,47 +55,39 @@ def get_device_delta(render_devices, active_devices):
return devices_to_start, devices_to_stop
def is_mps_available():
return (
platform.system() == "Darwin"
and hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and torch.backends.mps.is_built()
)
def validate_render_devices(render_devices):
supported_backends = ("auto",) + SUPPORTED_BACKENDS
unsupported_render_devices = [d for d in render_devices if not d.lower().startswith(supported_backends)]
def is_cuda_available():
return torch.cuda.is_available()
if unsupported_render_devices:
raise ValueError(
f"Invalid render devices in config: {unsupported_render_devices}. Valid render devices: {supported_backends}"
)
def auto_pick_devices(currently_active_devices):
global mem_free_threshold
if is_mps_available():
return ["mps"]
torch_platform_name = get_installed_torch_platform()[0]
if not is_cuda_available():
return ["cpu"]
device_count = torch.cuda.device_count()
if device_count == 1:
return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]
if is_cpu_device(torch_platform_name):
return [torch_platform_name]
device_count = get_device_count()
log.debug("Autoselecting GPU. Using most free memory.")
devices = []
for device in range(device_count):
device = f"cuda:{device}"
if not is_device_compatible(device):
continue
for device_id in range(device_count):
device_id = f"{torch_platform_name}:{device_id}" if device_count > 1 else torch_platform_name
device = get_device(device_id)
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free, mem_total = mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
device_name = get_device_name(device)
log.debug(
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
f"{device_id} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
)
devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})
devices.append({"device": device_id, "device_name": device_name, "mem_free": mem_free})
devices.sort(key=lambda x: x["mem_free"], reverse=True)
max_mem_free = devices[0]["mem_free"]
@@ -119,56 +100,45 @@ def auto_pick_devices(currently_active_devices):
# always be very low (since their VRAM contains the model).
# These already-running devices probably aren't terrible, since they were picked in the past.
# Worst case, the user can restart the program and that'll get rid of them.
devices = list(
filter(
(lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices),
devices,
)
)
devices = list(map(lambda x: x["device"], devices))
devices = [
x["device"] for x in devices if x["mem_free"] >= mem_free_threshold or x["device"] in currently_active_devices
]
return devices
def device_init(context, device):
"""
This function assumes the 'device' has already been verified to be compatible.
`get_device_delta()` has already filtered out incompatible devices.
"""
def device_init(context, device_id):
context.device = device_id
validate_device_id(device, log_prefix="device_init")
if "cuda" not in device:
context.device = device
if is_cpu_device(context.torch_device):
context.device_name = get_processor_name()
context.half_precision = False
log.debug(f"Render device available as {context.device_name}")
return
else:
context.device_name = get_device_name(context.torch_device)
context.device_name = torch.cuda.get_device_name(device)
context.device = device
# Some graphics cards have bugs in their firmware that prevent image generation at half precision
if needs_to_force_full_precision(context.device_name):
log.warn(f"forcing full precision on this GPU, to avoid corrupted images. GPU: {context.device_name}")
context.half_precision = False
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
if needs_to_force_full_precision(context):
log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
# Apply force_full_precision now before models are loaded.
context.half_precision = False
log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
torch.cuda.device(device)
log.info(f'Setting {device_id} as active, with precision: {"half" if context.half_precision else "full"}')
def needs_to_force_full_precision(context):
def needs_to_force_full_precision(device_name):
if "FORCE_FULL_PRECISION" in os.environ:
return True
device_name = context.device_name.lower()
return has_half_precision_bug(device_name)
return has_half_precision_bug(device_name.lower())
def get_max_vram_usage_level(device):
if "cuda" in device:
_, mem_total = torch.cuda.mem_get_info(device)
else:
"Expects a torch.device as the argument"
if is_cpu_device(device):
return "high"
_, mem_total = mem_get_info(device)
if mem_total < 0.001: # probably a torch platform without a mem_get_info() implementation
return "high"
mem_total /= float(10**9)
@@ -180,51 +150,6 @@ def get_max_vram_usage_level(device):
return "high"
def validate_device_id(device, log_prefix=""):
def is_valid():
if not isinstance(device, str):
return False
if device == "cpu" or device == "mps":
return True
if not device.startswith("cuda:") or not device[5:].isnumeric():
return False
return True
if not is_valid():
raise EnvironmentError(
f"{log_prefix}: device id should be 'cpu', 'mps', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
)
def is_device_compatible(device):
"""
Returns True/False, and prints any compatibility errors
"""
# static variable "history".
is_device_compatible.history = getattr(is_device_compatible, "history", {})
try:
validate_device_id(device, log_prefix="is_device_compatible")
except:
log.error(str(e))
return False
if device in ("cpu", "mps"):
return True
# Memory check
try:
_, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 1.9:
if is_device_compatible.history.get(device) == None:
log.warn(f"GPU {device} with less than 2 GB of VRAM is not compatible with Stable Diffusion")
is_device_compatible.history[device] = 1
return False
except RuntimeError as e:
log.error(str(e))
return False
return True
def get_processor_name():
try:
import subprocess

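End-to-end, the refactored device_manager is used roughly like this (a sketch of mine; it assumes sdkit's Context derives the torch_device attribute read by device_init() from the id assigned to context.device):

from sdkit import Context
from easydiffusion import device_manager

to_start, to_stop = device_manager.get_device_delta("auto", active_devices=[])
for device_id in to_start:  # e.g. ["cuda:0"] on a single-GPU machine, or ["cpu"]
    context = Context()
    device_manager.device_init(context, device_id)
    print(context.device, context.device_name, context.half_precision)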
View File

@@ -208,11 +208,13 @@ def set_app_config_internal(req: SetAppConfigRequest):
def update_render_devices_in_config(config, render_devices):
if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")
from easydiffusion.device_manager import validate_render_devices
if render_devices.startswith("cuda:"):
try:
render_devices = render_devices.split(",")
validate_render_devices(render_devices)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
config["render_devices"] = render_devices

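For example, the comma-separated device string sent by the UI is split and then checked by validate_render_devices() (my illustration, assuming "cuda" is among torchruntime's SUPPORTED_BACKENDS):

from easydiffusion.device_manager import validate_render_devices

validate_render_devices(["cuda:0", "cuda:1"])  # passes: "cuda" is a supported backend prefix
try:
    validate_render_devices(["foo:0"])  # no supported backend prefix
except ValueError as e:
    print(e)  # the endpoint above reports this to the client as an HTTP 400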
View File

@@ -21,6 +21,9 @@ from easydiffusion import device_manager
from easydiffusion.tasks import Task
from easydiffusion.utils import log
from torchruntime.utils import get_device_count, get_device, get_device_name, get_installed_torch_platform
from sdkit.utils import is_cpu_device, mem_get_info
THREAD_NAME_PREFIX = ""
ERR_LOCK_FAILED = " failed to acquire lock within timeout."
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
@@ -339,34 +342,33 @@ def get_devices():
"active": {},
}
def get_device_info(device):
if device in ("cpu", "mps"):
def get_device_info(device_id):
if is_cpu_device(device_id):
return {"name": device_manager.get_processor_name()}
mem_free, mem_total = torch.cuda.mem_get_info(device)
device = get_device(device_id)
mem_free, mem_total = mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
return {
"name": torch.cuda.get_device_name(device),
"name": get_device_name(device),
"mem_free": mem_free,
"mem_total": mem_total,
"max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
}
# list the compatible devices
cuda_count = torch.cuda.device_count()
for device in range(cuda_count):
device = f"cuda:{device}"
if not device_manager.is_device_compatible(device):
continue
torch_platform_name = get_installed_torch_platform()[0]
device_count = get_device_count()
for device_id in range(device_count):
device_id = f"{torch_platform_name}:{device_id}" if device_count > 1 else torch_platform_name
devices["all"].update({device: get_device_info(device)})
devices["all"].update({device_id: get_device_info(device_id)})
if device_manager.is_mps_available():
devices["all"].update({"mps": get_device_info("mps")})
devices["all"].update({"cpu": get_device_info("cpu")})
if torch_platform_name != "cpu":
devices["all"].update({"cpu": get_device_info("cpu")})
# list the activated devices
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
@@ -378,8 +380,8 @@ def get_devices():
weak_data = weak_thread_data.get(rthread)
if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
continue
device = weak_data["device"]
devices["active"].update({device: get_device_info(device)})
device_id = weak_data["device"]
devices["active"].update({device_id: get_device_info(device_id)})
finally:
manager_lock.release()
@@ -437,12 +439,6 @@ def start_render_thread(device):
def stop_render_thread(device):
try:
device_manager.validate_device_id(device, log_prefix="stop_render_thread")
except:
log.error(traceback.format_exc())
return False
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
log.info(f"Stopping Rendering Thread on device: {device}")

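For reference, the dict returned by get_devices() and consumed by the UI code below, with illustrative values of mine (mem_free / mem_total are in GB after the division above):

{
    "all": {
        "cuda:0": {"name": "GeForce RTX 3060", "mem_free": 11.2, "mem_total": 12.0, "max_vram_usage_level": "high"},
        "cpu": {"name": "Intel(R) Core(TM) i7"},  # CPU entries only carry a name; added only when the torch platform isn't already "cpu"
    },
    "active": {
        "cuda:0": {"name": "GeForce RTX 3060", "mem_free": 2.3, "mem_total": 12.0, "max_vram_usage_level": "high"},
    },
}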
View File

@@ -182,23 +182,6 @@ function loadSettings() {
}
})
CURRENTLY_LOADING_SETTINGS = false
} else if (localStorage.length < 2) {
// localStorage is too short for OldSettings
// So this is likely the first time Easy Diffusion is running.
// Initialize vram_usage_level based on the available VRAM
function initGPUProfile(event) {
if (
"detail" in event &&
"active" in event.detail &&
"cuda:0" in event.detail.active &&
event.detail.active["cuda:0"].mem_total < 4.5
) {
vramUsageLevelField.value = "low"
vramUsageLevelField.dispatchEvent(new Event("change"))
}
document.removeEventListener("system_info_update", initGPUProfile)
}
document.addEventListener("system_info_update", initGPUProfile)
} else {
CURRENTLY_LOADING_SETTINGS = true
tryLoadOldSettings()

View File

@@ -658,7 +658,7 @@ function setDeviceInfo(devices) {
function ID_TO_TEXT(d) {
let info = devices.all[d]
if ("mem_free" in info && "mem_total" in info) {
if ("mem_free" in info && "mem_total" in info && info["mem_total"] > 0) {
return `${info.name} <small>(${d}) (${info.mem_free.toFixed(1)}Gb free / ${info.mem_total.toFixed(
1
)} Gb total)</small>`