diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index cd9c6a3c..18631469 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -5,6 +5,8 @@ import re COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked +mem_free_threshold = 0 + def get_device_delta(render_devices, active_devices): ''' render_devices: 'cpu', or 'auto' or ['cuda:N'...] @@ -41,6 +43,8 @@ def get_device_delta(render_devices, active_devices): return devices_to_start, devices_to_stop def auto_pick_devices(currently_active_devices): + global mem_free_threshold + if not torch.cuda.is_available(): return ['cpu'] device_count = torch.cuda.device_count() @@ -62,8 +66,9 @@ def auto_pick_devices(currently_active_devices): devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free}) devices.sort(key=lambda x:x['mem_free'], reverse=True) - max_free_mem = devices[0]['mem_free'] - free_mem_threshold = COMPARABLE_GPU_PERCENTILE * max_free_mem + max_mem_free = devices[0]['mem_free'] + curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free + mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold) # Auto-pick algorithm: # 1. Pick the top 75 percentile of the GPUs, sorted by free_mem. @@ -71,7 +76,7 @@ def auto_pick_devices(currently_active_devices): # always be very low (since their VRAM contains the model). # These already-running devices probably aren't terrible, since they were picked in the past. # Worst case, the user can restart the program and that'll get rid of them. - devices = list(filter((lambda x: x['mem_free'] > free_mem_threshold or x['device'] in currently_active_devices), devices)) + devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices)) devices = list(map(lambda x: x['device'], devices)) return devices