Remove hardcoded torch.cuda references from the webui backend code

This commit is contained in:
cmdr2 2025-02-06 13:07:10 +05:30
parent 9e21d681a0
commit 2efef8043e

View File

@ -10,6 +10,8 @@ import shutil
from easydiffusion.app import ROOT_DIR, getConfig
from easydiffusion.model_manager import get_model_dirs
from easydiffusion.utils import log
from torchruntime.utils import get_device, get_device_name, get_installed_torch_platform
from sdkit.utils import is_cpu_device
from . import impl
from .impl import (
@ -93,7 +95,9 @@ def install_backend():
# install cpu-only torch if the PC doesn't have a graphics card (for Windows and Linux).
# this avoids WebUI installing a CUDA version and trying to activate it
if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card():
torch_platform_name = get_installed_torch_platform()[0]
if OS_NAME in ("Windows", "Linux") and is_cpu_device(torch_platform_name):
run_in_conda(["python", "-m", "pip", "install", "torch", "torchvision"], cwd=WEBUI_DIR, env=env)
@ -296,7 +300,8 @@ def create_context():
context = local()
# temp hack, throws an attribute not found error otherwise
context.device = "cuda:0"
context.torch_device = get_device(0)
context.device = f"{context.torch_device.type}:{context.torch_device.index}"
context.half_precision = True
context.vram_usage_level = None
@ -379,17 +384,16 @@ def get_env():
else:
env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
else:
import torch
from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
from easydiffusion.device_manager import needs_to_force_full_precision
torch_platform_name = get_installed_torch_platform()[0]
vram_usage_level = config.get("vram_usage_level", "balanced")
if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
if config.get("render_devices", "auto") == "cpu" or is_cpu_device(torch_platform_name):
env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
else:
c = local()
c.device_name = torch.cuda.get_device_name()
if needs_to_force_full_precision(c):
device = get_device(0)
if needs_to_force_full_precision(get_device_name(device)):
env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
if vram_usage_level == "low":
@ -407,48 +411,6 @@ def get_env():
return env
def has_discrete_graphics_card():
system = OS_NAME
if system == "Windows":
try:
env = dict(os.environ)
env["PATH"] += os.pathsep + "C:/Windows/System32/WindowsPowerShell/v1.0".replace("/", os.path.sep)
# PowerShell command to get the names of graphics cards
command = [
"powershell",
"-Command",
"(Get-WmiObject Win32_VideoController).Name",
]
# Run the command and decode the output
output = subprocess.check_output(command, universal_newlines=True, stderr=subprocess.STDOUT, env=env)
# Filter for discrete graphics cards (NVIDIA, AMD, etc.)
discrete_gpus = ["NVIDIA", "AMD", "ATI"]
return any(gpu in output for gpu in discrete_gpus)
except subprocess.CalledProcessError:
return False
elif system == "Linux":
try:
output = subprocess.check_output(["lspci"], stderr=subprocess.STDOUT)
# Check for discrete GPUs (NVIDIA, AMD)
discrete_gpus = ["NVIDIA", "AMD", "Advanced Micro Devices"]
return any(gpu in line for line in output.decode().splitlines() for gpu in discrete_gpus)
except subprocess.CalledProcessError:
return False
elif system == "Darwin": # macOS
try:
output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], stderr=subprocess.STDOUT)
# Check for discrete GPU in the output
return "NVIDIA" in output.decode() or "AMD" in output.decode()
except subprocess.CalledProcessError:
return False
return False
# https://stackoverflow.com/a/25134985
def kill(proc_pid):
process = psutil.Process(proc_pid)