mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-02 07:14:27 +02:00
Report the device GPU memory (and existence) correctly for mps (mac)
This commit is contained in:
parent
1b7af75d4e
commit
d1a45ed9ac
@ -22,17 +22,17 @@ mem_free_threshold = 0
|
|||||||
|
|
||||||
def get_device_delta(render_devices, active_devices):
|
def get_device_delta(render_devices, active_devices):
|
||||||
"""
|
"""
|
||||||
render_devices: 'cpu', or 'auto' or ['cuda:N'...]
|
render_devices: 'cpu', or 'auto', or 'mps' or ['cuda:N'...]
|
||||||
active_devices: ['cpu', 'cuda:N'...]
|
active_devices: ['cpu', 'mps', 'cuda:N'...]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if render_devices in ("cpu", "auto"):
|
if render_devices in ("cpu", "auto", "mps"):
|
||||||
render_devices = [render_devices]
|
render_devices = [render_devices]
|
||||||
elif render_devices is not None:
|
elif render_devices is not None:
|
||||||
if isinstance(render_devices, str):
|
if isinstance(render_devices, str):
|
||||||
render_devices = [render_devices]
|
render_devices = [render_devices]
|
||||||
if isinstance(render_devices, list) and len(render_devices) > 0:
|
if isinstance(render_devices, list) and len(render_devices) > 0:
|
||||||
render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices))
|
render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
|
||||||
if len(render_devices) == 0:
|
if len(render_devices) == 0:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
||||||
@ -160,10 +160,14 @@ def needs_to_force_full_precision(context):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_vram_usage_level(device):
|
def get_max_vram_usage_level(device):
|
||||||
if device != "cpu":
|
if "cuda" in device:
|
||||||
_, mem_total = torch.cuda.mem_get_info(device)
|
_, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_total /= float(10**9)
|
elif device == "mps":
|
||||||
|
mem_total = torch.mps.driver_allocated_memory()
|
||||||
|
else:
|
||||||
|
return "high"
|
||||||
|
|
||||||
|
mem_total /= float(10**9)
|
||||||
if mem_total < 4.5:
|
if mem_total < 4.5:
|
||||||
return "low"
|
return "low"
|
||||||
elif mem_total < 6.5:
|
elif mem_total < 6.5:
|
||||||
@ -200,7 +204,7 @@ def is_device_compatible(device):
|
|||||||
log.error(str(e))
|
log.error(str(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu" or device == "mps":
|
||||||
return True
|
return True
|
||||||
# Memory check
|
# Memory check
|
||||||
try:
|
try:
|
||||||
|
@ -385,10 +385,16 @@ def get_devices():
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_device_info(device):
|
def get_device_info(device):
|
||||||
if "cuda" not in device:
|
if device == "cpu":
|
||||||
return {"name": device_manager.get_processor_name()}
|
return {"name": device_manager.get_processor_name()}
|
||||||
|
|
||||||
|
if device == "mps":
|
||||||
|
mem_used = torch.mps.current_allocated_memory()
|
||||||
|
mem_total = torch.mps.driver_allocated_memory()
|
||||||
|
mem_free = mem_total - mem_used
|
||||||
|
else:
|
||||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||||
|
|
||||||
mem_free /= float(10**9)
|
mem_free /= float(10**9)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
|
|
||||||
@ -400,14 +406,17 @@ def get_devices():
|
|||||||
}
|
}
|
||||||
|
|
||||||
# list the compatible devices
|
# list the compatible devices
|
||||||
gpu_count = torch.cuda.device_count()
|
cuda_count = torch.cuda.device_count()
|
||||||
for device in range(gpu_count):
|
for device in range(cuda_count):
|
||||||
device = f"cuda:{device}"
|
device = f"cuda:{device}"
|
||||||
if not device_manager.is_device_compatible(device):
|
if not device_manager.is_device_compatible(device):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
devices["all"].update({device: get_device_info(device)})
|
devices["all"].update({device: get_device_info(device)})
|
||||||
|
|
||||||
|
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
||||||
|
devices["all"].update({"mps": get_device_info("mps")})
|
||||||
|
|
||||||
devices["all"].update({"cpu": get_device_info("cpu")})
|
devices["all"].update({"cpu": get_device_info("cpu")})
|
||||||
|
|
||||||
# list the activated devices
|
# list the activated devices
|
||||||
|
Loading…
Reference in New Issue
Block a user