Refactor force-fp32 to not crash on PCs without cuda

This commit is contained in:
cmdr2 2024-10-12 12:45:48 +05:30
parent 76c8e18fcf
commit 7f9394b621

View File

@ -371,23 +371,23 @@ def get_env():
else:
env_entries["TORCH_COMMAND"] = ["pip install torch==2.3.1 torchvision==0.18.1"]
else:
import torch
from easydiffusion.device_manager import needs_to_force_full_precision, is_cuda_available
vram_usage_level = config.get("vram_usage_level", "balanced")
if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card():
if config.get("render_devices", "auto") == "cpu" or not has_discrete_graphics_card() or not is_cuda_available():
env_entries["COMMANDLINE_ARGS"][0] += " --always-cpu"
elif vram_usage_level == "low":
env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram"
elif vram_usage_level == "high":
env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram"
else:
c = local()
c.device_name = torch.cuda.get_device_name()
# check and force full-precision on NVIDIA 16xx graphics cards
import torch
from easydiffusion.device_manager import needs_to_force_full_precision
if needs_to_force_full_precision(c):
env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
c = local()
c.device_name = torch.cuda.get_device_name()
if needs_to_force_full_precision(c):
env_entries["COMMANDLINE_ARGS"][0] += " --no-half --precision full"
if vram_usage_level == "low":
env_entries["COMMANDLINE_ARGS"][0] += " --always-low-vram"
elif vram_usage_level == "high":
env_entries["COMMANDLINE_ARGS"][0] += " --always-high-vram"
env = {}
for key, paths in env_entries.items():