Refactor force-fp32 to not crash on PCs without cuda

cmdr2 2024-10-12 12:45:48 +05:30
parent 76c8e18fcf
commit 7f9394b621

@@ -371,23 +371,23 @@ def get_env():
         else:
             env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
     else:
+        import torch
+        from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
+
         vram_usage_level = config.get("vram_usage_level", "balanced")
-        if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card():
+        if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
             env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
-        elif vram_usage_level == "low":
-            env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram"
-        elif vram_usage_level == "high":
-            env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram"
-
-        # check and force full-precision on NVIDIA 16xx graphics cards
-        import torch
-        from easydiffusion.device_manager import needs_to_force_full_precision
-
-        c = local()
-        c.device_name = torch.cuda.get_device_name()
-
-        if needs_to_force_full_precision(c):
-            env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
+        else:
+            c = local()
+            c.device_name = torch.cuda.get_device_name()
+
+            if needs_to_force_full_precision(c):
+                env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
+
+            if vram_usage_level == "low":
+                env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram"
+            elif vram_usage_level == "high":
+                env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram"
 
     env = {}
     for key, paths in env_entries.items():
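
The crash this commit fixes came from calling torch.cuda.get_device_name() unconditionally: on a machine with no CUDA device (or a CPU-only torch build) that call raises a RuntimeError. The refactor adds is_cuda_available() to the CPU-fallback condition and only queries the device name, the full-precision check, and the VRAM flags inside the else branch, where a CUDA device is known to exist. The two helpers imported from easydiffusion.device_manager are not shown in this diff; the sketch below illustrates what the change assumes about them, and is not the project's actual implementation.

# Illustrative sketch only -- the real helpers live in easydiffusion/device_manager.py
# and may differ in detail; shown to clarify what the diff above relies on.
import re
import torch


def is_cuda_available() -> bool:
    # False on CPU-only torch builds or machines without an NVIDIA driver,
    # so callers can skip torch.cuda.* calls that would otherwise raise.
    return torch.cuda.is_available()


def needs_to_force_full_precision(context) -> bool:
    # Half precision is known to produce black images on some GPUs
    # (e.g. the NVIDIA GTX 16xx series), so fp32 is forced for them.
    # The real check in Easy Diffusion may cover more devices than this.
    name = context.device_name.lower()
    is_nvidia = "nvidia" in name or "geforce" in name
    return is_nvidia and bool(re.search(r"\b16\d\d\b", name))

With this split, a PC without CUDA takes the --always-cpu path before any torch.cuda call is made, while GPU machines keep the previous behaviour (forced fp32 on affected cards plus the low/high VRAM flags).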