Move the half precision bug check logic to sdkit

This commit is contained in:
cmdr2 2025-01-31 16:28:12 +05:30
parent d5a7c1bdf6
commit 35f752b36d

View File

@ -6,6 +6,8 @@ import traceback
import torch import torch
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit.utils import has_half_precision_bug
""" """
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
Otherwise the models will load at half-precision (i.e. float16). Otherwise the models will load at half-precision (i.e. float16).
@ -160,20 +162,7 @@ def needs_to_force_full_precision(context):
return True return True
device_name = context.device_name.lower() device_name = context.device_name.lower()
return ( return has_half_precision_bug(device_name)
("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name)
and (
" 1660" in device_name
or " 1650" in device_name
or " 1630" in device_name
or " t400" in device_name
or " t550" in device_name
or " t600" in device_name
or " t1000" in device_name
or " t1200" in device_name
or " t2000" in device_name
)
) or ("tesla k40m" in device_name)
def get_max_vram_usage_level(device): def get_max_vram_usage_level(device):