import os
import re
import traceback
import logging

import torch

log = logging.getLogger()

'''
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh`, to force full precision (i.e. float32).
Otherwise the models will load at half precision (i.e. float16).

Half precision is fine most of the time. Full precision is only needed to work around GPU bugs (e.g. on NVIDIA 16xx GPUs, which can produce green images at half precision).
'''
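
# A minimal example of forcing full precision (the config file syntax shown here is an
# assumption; only the variable's presence is checked, so any value works):
#
#   config.sh (Linux/macOS):  export FORCE_FULL_PRECISION=yes
#   config.bat (Windows):     set FORCE_FULL_PRECISION=yes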

COMPARABLE_GPU_PERCENTILE = 0.65  # if a GPU's free_mem is within this fraction of the GPU with the most free_mem, it will be picked

mem_free_threshold = 0


def get_device_delta(render_devices, active_devices):
    '''
    render_devices: 'cpu', or 'auto', or ['cuda:N'...]
    active_devices: ['cpu', 'cuda:N'...]
    '''

    if render_devices in ('cpu', 'auto'):
        render_devices = [render_devices]
    elif render_devices is not None:
        if isinstance(render_devices, str):
            render_devices = [render_devices]

        if isinstance(render_devices, list) and len(render_devices) > 0:
            render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices))
            if len(render_devices) == 0:
                raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')

            render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
            if len(render_devices) == 0:
                raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion')
        else:
            raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
    else:
        render_devices = ['auto']

    if 'auto' in render_devices:
        render_devices = auto_pick_devices(active_devices)
        if 'cpu' in render_devices:
            log.warning('Could not find a compatible GPU. Using the CPU, but this will be very slow!')

    active_devices = set(active_devices)
    render_devices = set(render_devices)

    devices_to_start = render_devices - active_devices
    devices_to_stop = active_devices - render_devices

    return devices_to_start, devices_to_stop
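

# A quick usage sketch (hypothetical values): if config.json requests 'auto' and only the
# CPU render thread is currently active, the delta tells the caller what to start and stop:
#
#   to_start, to_stop = get_device_delta('auto', ['cpu'])
#   # e.g. to_start == {'cuda:0'}, to_stop == {'cpu'} on a machine with one compatible GPU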


def auto_pick_devices(currently_active_devices):
    global mem_free_threshold

    if not torch.cuda.is_available():
        return ['cpu']

    device_count = torch.cuda.device_count()
    if device_count == 1:
        return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']

    log.debug('Autoselecting GPU. Using most free memory.')
    devices = []
    for device in range(device_count):
        device = f'cuda:{device}'
        if not is_device_compatible(device):
            continue

        mem_free, mem_total = torch.cuda.mem_get_info(device)
        mem_free /= float(10**9)
        mem_total /= float(10**9)
        device_name = torch.cuda.get_device_name(device)
        log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)} GB / {round(mem_total, 2)} GB')
        devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})

    devices.sort(key=lambda x: x['mem_free'], reverse=True)
    max_mem_free = devices[0]['mem_free']
    curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
    mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)

    # Auto-pick algorithm:
    # 1. Pick the GPUs whose free_mem is within COMPARABLE_GPU_PERCENTILE (i.e. 65%) of
    #    the GPU with the most free_mem, sorted by free_mem.
    # 2. Also include already-running devices (GPU-only), otherwise their free_mem will
    #    always be very low (since their VRAM contains the model).
    #    These already-running devices probably aren't terrible, since they were picked in the past.
    #    Worst case, the user can restart the program and that'll get rid of them.
    devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices))
    devices = list(map(lambda x: x['device'], devices))

    return devices
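

# A worked example with hypothetical numbers: with two compatible GPUs reporting 10 GB and
# 7 GB free, max_mem_free is 10, so the threshold is 0.65 * 10 = 6.5 GB and both GPUs are
# picked; a third GPU with only 4 GB free would be skipped unless it was already rendering.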


def device_init(context, device):
    '''
    This function assumes the 'device' has already been verified to be compatible.
    `get_device_delta()` has already filtered out incompatible devices.
    '''

    validate_device_id(device, log_prefix='device_init')

    if device == 'cpu':
        context.device = 'cpu'
        context.device_name = get_processor_name()
        context.half_precision = False
        log.debug(f'Render device CPU available as {context.device_name}')
        return

    context.device_name = torch.cuda.get_device_name(device)
    context.device = device

    # Force full precision on 1660 and 1650 NVIDIA cards, to avoid creating green images
    if needs_to_force_full_precision(context):
        log.warning(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
        # Apply force_full_precision now, before the models are loaded.
        context.half_precision = False

    log.info(f'Setting {device} as active')
    torch.cuda.set_device(device)  # make this GPU the current CUDA device

    return
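

# A minimal usage sketch. 'Context' is a stand-in for whatever state object the caller
# passes in; the only attributes this module touches are device, device_name, and
# half_precision:
#
#   class Context:
#       device = None
#       device_name = None
#       half_precision = True
#
#   context = Context()
#   device_init(context, 'cuda:0')  # or 'cpu'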


def needs_to_force_full_precision(context):
    if 'FORCE_FULL_PRECISION' in os.environ:
        return True

    device_name = context.device_name.lower()
    # device_name is lowercased above, so all matches must be lowercase ('t2000' covers the Quadro T2000)
    return (('nvidia' in device_name or 'geforce' in device_name) and ('1660' in device_name or '1650' in device_name)) or ('t2000' in device_name)


def validate_device_id(device, log_prefix=''):
    def is_valid():
        if not isinstance(device, str):
            return False
        if device == 'cpu':
            return True
        if not device.startswith('cuda:') or not device[5:].isnumeric():
            return False
        return True

    if not is_valid():
        raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
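

# Accepted and rejected ids, for illustration:
#
#   validate_device_id('cpu')       # ok
#   validate_device_id('cuda:1')    # ok
#   validate_device_id('cuda:abc')  # raises EnvironmentError
#   validate_device_id(0)           # raises EnvironmentError (must be a string)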


def is_device_compatible(device):
    '''
    Returns True/False, and logs any compatibility errors
    '''
    try:
        validate_device_id(device, log_prefix='is_device_compatible')
    except EnvironmentError as e:
        log.error(str(e))
        return False

    if device == 'cpu':
        return True

    # Memory check
    try:
        _, mem_total = torch.cuda.mem_get_info(device)
        mem_total /= float(10**9)
        if mem_total < 3.0:
            log.warning(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
            return False
    except RuntimeError as e:
        log.error(str(e))
        return False

    return True


def get_processor_name():
    try:
        import platform, subprocess

        if platform.system() == "Windows":
            return platform.processor()
        elif platform.system() == "Darwin":
            os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
            command = "sysctl -n machdep.cpu.brand_string"
            return subprocess.check_output(command.split()).decode().strip()
        elif platform.system() == "Linux":
            command = "cat /proc/cpuinfo"
            all_info = subprocess.check_output(command, shell=True).decode().strip()
            for line in all_info.split("\n"):
                if "model name" in line:
                    return re.sub(".*model name.*:", "", line, count=1).strip()
    except Exception:
        log.error(traceback.format_exc())

    return "cpu"