mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-18 23:30:46 +02:00
Move the half precision bug check logic to sdkit
This commit is contained in:
parent
d5a7c1bdf6
commit
35f752b36d
@ -6,6 +6,8 @@ import traceback
|
|||||||
import torch
|
import torch
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
|
from sdkit.utils import has_half_precision_bug
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
||||||
Otherwise the models will load at half-precision (i.e. float16).
|
Otherwise the models will load at half-precision (i.e. float16).
|
||||||
@ -160,20 +162,7 @@ def needs_to_force_full_precision(context):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
device_name = context.device_name.lower()
|
device_name = context.device_name.lower()
|
||||||
return (
|
return has_half_precision_bug(device_name)
|
||||||
("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name)
|
|
||||||
and (
|
|
||||||
" 1660" in device_name
|
|
||||||
or " 1650" in device_name
|
|
||||||
or " 1630" in device_name
|
|
||||||
or " t400" in device_name
|
|
||||||
or " t550" in device_name
|
|
||||||
or " t600" in device_name
|
|
||||||
or " t1000" in device_name
|
|
||||||
or " t1200" in device_name
|
|
||||||
or " t2000" in device_name
|
|
||||||
)
|
|
||||||
) or ("tesla k40m" in device_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_max_vram_usage_level(device):
|
def get_max_vram_usage_level(device):
|
||||||
|
Loading…
Reference in New Issue
Block a user