mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-03 08:36:12 +02:00
Remove hardcoded torch.cuda references from the webui backend code
This commit is contained in:
parent
9e21d681a0
commit
2efef8043e
@ -10,6 +10,8 @@ import shutil
|
||||
from easydiffusion.app import ROOT_DIR, getConfig
|
||||
from easydiffusion.model_manager import get_model_dirs
|
||||
from easydiffusion.utils import log
|
||||
from torchruntime.utils import get_device, get_device_name, get_installed_torch_platform
|
||||
from sdkit.utils import is_cpu_device
|
||||
|
||||
from . import impl
|
||||
from .impl import (
|
||||
@ -93,7 +95,9 @@ def install_backend():
|
||||
|
||||
# install cpu-only torch if the PC doesn't have a graphics card (for Windows and Linux).
|
||||
# this avoids WebUI installing a CUDA version and trying to activate it
|
||||
if OS_NAME in ("Windows", "Linux") and not has_discrete_graphics_card():
|
||||
|
||||
torch_platform_name = get_installed_torch_platform()[0]
|
||||
if OS_NAME in ("Windows", "Linux") and is_cpu_device(torch_platform_name):
|
||||
run_in_conda(["python", "-m", "pip", "install", "torch", "torchvision"], cwd=WEBUI_DIR, env=env)
|
||||
|
||||
|
||||
@ -296,7 +300,8 @@ def create_context():
|
||||
context = local()
|
||||
|
||||
# temp hack, throws an attribute not found error otherwise
|
||||
context.device = "cuda:0"
|
||||
context.torch_device = get_device(0)
|
||||
context.device = f"{context.torch_device.type}:{context.torch_device.index}"
|
||||
context.half_precision = True
|
||||
context.vram_usage_level = None
|
||||
|
||||
@ -379,17 +384,16 @@ def get_env():
|
||||
else:
|
||||
env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
|
||||
else:
|
||||
import torch
|
||||
from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
|
||||
from easydiffusion.device_manager import needs_to_force_full_precision
|
||||
|
||||
torch_platform_name = get_installed_torch_platform()[0]
|
||||
|
||||
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||
if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
|
||||
if config.get("render_devices", "auto") == "cpu" or is_cpu_device(torch_platform_name):
|
||||
env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
|
||||
else:
|
||||
c = local()
|
||||
c.device_name = torch.cuda.get_device_name()
|
||||
|
||||
if needs_to_force_full_precision(c):
|
||||
device = get_device(0)
|
||||
if needs_to_force_full_precision(get_device_name(device)):
|
||||
env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
|
||||
|
||||
if vram_usage_level == "low":
|
||||
@ -407,48 +411,6 @@ def get_env():
|
||||
return env
|
||||
|
||||
|
||||
def has_discrete_graphics_card():
|
||||
system = OS_NAME
|
||||
|
||||
if system == "Windows":
|
||||
try:
|
||||
env = dict(os.environ)
|
||||
env["PATH"] += os.pathsep + "C:/Windows/System32/WindowsPowerShell/v1.0".replace("/", os.path.sep)
|
||||
|
||||
# PowerShell command to get the names of graphics cards
|
||||
command = [
|
||||
"powershell",
|
||||
"-Command",
|
||||
"(Get-WmiObject Win32_VideoController).Name",
|
||||
]
|
||||
# Run the command and decode the output
|
||||
output = subprocess.check_output(command, universal_newlines=True, stderr=subprocess.STDOUT, env=env)
|
||||
# Filter for discrete graphics cards (NVIDIA, AMD, etc.)
|
||||
discrete_gpus = ["NVIDIA", "AMD", "ATI"]
|
||||
return any(gpu in output for gpu in discrete_gpus)
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
elif system == "Linux":
|
||||
try:
|
||||
output = subprocess.check_output(["lspci"], stderr=subprocess.STDOUT)
|
||||
# Check for discrete GPUs (NVIDIA, AMD)
|
||||
discrete_gpus = ["NVIDIA", "AMD", "Advanced Micro Devices"]
|
||||
return any(gpu in line for line in output.decode().splitlines() for gpu in discrete_gpus)
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
elif system == "Darwin": # macOS
|
||||
try:
|
||||
output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"], stderr=subprocess.STDOUT)
|
||||
# Check for discrete GPU in the output
|
||||
return "NVIDIA" in output.decode() or "AMD" in output.decode()
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# https://stackoverflow.com/a/25134985
|
||||
def kill(proc_pid):
|
||||
process = psutil.Process(proc_pid)
|
||||
|
Loading…
x
Reference in New Issue
Block a user