From 2efef8043e235bd9b730cbb2ada47b175ddf01fe Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Thu, 6 Feb 2025 13:07:10 +0530
Subject: [PATCH] Remove hardcoded torch.cuda references from the webui backend code

---
 ui/easydiffusion/backends/webui/__init__.py | 64 +++++----------------
 1 file changed, 13 insertions(+), 51 deletions(-)

diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py
index 68ef640b..65743710 100644
--- a/ui/easydiffusion/backends/webui/__init__.py
+++ b/ui/easydiffusion/backends/webui/__init__.py
@@ -10,6 +10,8 @@ import shutil
 from easydiffusion.app import ROOT_DIR, getConfig
 from easydiffusion.model_manager import get_model_dirs
 from easydiffusion.utils import log
+from torchruntime.utils import get_device, get_device_name, get_installed_torch_platform
+from sdkit.utils import is_cpu_device
 
 from . import impl
 from .impl import (
@@ -93,7 +95,9 @@ def install_backend():
 
     # install cpu-only torch if the PC doesn't have a graphics card (for Windows and Linux).
     # this avoids WebUI installing a CUDA version and trying to activate it
-    if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card():
+
+    torch_platform_name = get_installed_torch_platform()[0]
+    if OS_NAME in ("Windows", "Linux") and is_cpu_device(torch_platform_name):
         run_in_conda(["python", "-m", "pip", "install", "torch", "torchvision"], cwd=WEBUI_DIR, env=env)
 
 
@@ -296,7 +300,8 @@ def create_context():
     context = local()
 
     # temp hack, throws an attribute not found error otherwise
-    context.device = "cuda:0"
+    context.torch_device = get_device(0)
+    context.device = f"{context.torch_device.type}:{context.torch_device.index}"
     context.half_precision = True
     context.vram_usage_level = None
 
@@ -379,17 +384,16 @@ def get_env():
         else:
             env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
     else:
-        import torch
-        from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
+        from easydiffusion.device_manager import needs_to_force_full_precision
+
+        torch_platform_name = get_installed_torch_platform()[0]
         vram_usage_level = config.get("vram_usage_level", "balanced")
 
-        if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
+        if config.get("render_devices", "auto") == "cpu" or is_cpu_device(torch_platform_name):
             env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
         else:
-            c = local()
-            c.device_name = torch.cuda.get_device_name()
-
-            if needs_to_force_full_precision(c):
+            device = get_device(0)
+            if needs_to_force_full_precision(get_device_name(device)):
                 env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
 
             if vram_usage_level == "low":
@@ -407,48 +411,6 @@ def get_env():
     return env
 
 
-def has_discrete_graphics_card():
-    system = OS_NAME
-
-    if system == "Windows":
-        try:
-            env = dict(os.environ)
-            env["PATH"] += os.pathsep + "C:/Windows/System32/WindowsPowerShell/v1.0".replace("/", os.path.sep)
-
-            # PowerShell command to get the names of graphics cards
-            command = [
-                "powershell",
-                "-Command",
-                "(Get-WmiObject Win32_VideoController).Name",
-            ]
-            # Run the command and decode the output
-            output = subprocess.check_output(command, universal_newlines=True, stderr=subprocess.STDOUT, env=env)
-            # Filter for discrete graphics cards (NVIDIA, AMD, etc.)
-            discrete_gpus = ["NVIDIA", "AMD", "ATI"]
-            return any(gpu in output for gpu in discrete_gpus)
-        except subprocess.CalledProcessError:
-            return False
-
-    elif system == "Linux":
-        try:
-            output = subprocess.check_output(["lspci"], stderr=subprocess.STDOUT)
-            # Check for discrete GPUs (NVIDIA, AMD)
-            discrete_gpus = ["NVIDIA", "AMD", "Advanced Micro Devices"]
-            return any(gpu in line for line in output.decode().splitlines() for gpu in discrete_gpus)
-        except subprocess.CalledProcessError:
-            return False
-
-    elif system == "Darwin":  # macOS
-        try:
-            output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], stderr=subprocess.STDOUT)
-            # Check for discrete GPU in the output
-            return "NVIDIA" in output.decode() or "AMD" in output.decode()
-        except subprocess.CalledProcessError:
-            return False
-
-    return False
-
-
 # https://stackoverflow.com/a/25134985
 def kill(proc_pid):
     process = psutil.Process(proc_pid)
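A minimal usage sketch (not part of the patch) of how the torchruntime/sdkit helpers introduced above fit together. Only the call signatures are taken from the diff; the example platform and device strings are assumptions.

    from torchruntime.utils import get_device, get_device_name, get_installed_torch_platform
    from sdkit.utils import is_cpu_device

    # Name of the installed torch platform, e.g. "cpu" (assumed example value)
    torch_platform_name = get_installed_torch_platform()[0]

    if is_cpu_device(torch_platform_name):
        # Mirrors install_backend()/get_env(): stay on the CPU-only code path
        print("CPU-only torch platform detected")
    else:
        device = get_device(0)  # first device, as used in create_context()
        device_str = f"{device.type}:{device.index}"  # e.g. "cuda:0" (assumed example)
        print(f"Rendering on {device_str} ({get_device_name(device)})")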