mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-18 15:20:47 +02:00
Move the half precision bug check logic to sdkit
This commit is contained in:
parent
d5a7c1bdf6
commit
35f752b36d
@ -6,6 +6,8 @@ import traceback
|
||||
import torch
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit.utils import has_half_precision_bug
|
||||
|
||||
"""
|
||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
||||
Otherwise the models will load at half-precision (i.e. float16).
|
||||
@ -160,20 +162,7 @@ def needs_to_force_full_precision(context):
|
||||
return True
|
||||
|
||||
device_name = context.device_name.lower()
|
||||
return (
|
||||
("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name)
|
||||
and (
|
||||
" 1660" in device_name
|
||||
or " 1650" in device_name
|
||||
or " 1630" in device_name
|
||||
or " t400" in device_name
|
||||
or " t550" in device_name
|
||||
or " t600" in device_name
|
||||
or " t1000" in device_name
|
||||
or " t1200" in device_name
|
||||
or " t2000" in device_name
|
||||
)
|
||||
) or ("tesla k40m" in device_name)
|
||||
return has_half_precision_bug(device_name)
|
||||
|
||||
|
||||
def get_max_vram_usage_level(device):
|
||||
|
Loading…
Reference in New Issue
Block a user