diff --git a/ui/easydiffusion/device_manager.py b/ui/easydiffusion/device_manager.py index 8c486dd2..d8dc4d86 100644 --- a/ui/easydiffusion/device_manager.py +++ b/ui/easydiffusion/device_manager.py @@ -22,17 +22,17 @@ mem_free_threshold = 0 def get_device_delta(render_devices, active_devices): """ - render_devices: 'cpu', or 'auto' or ['cuda:N'...] - active_devices: ['cpu', 'cuda:N'...] + render_devices: 'cpu', or 'auto', or 'mps' or ['cuda:N'...] + active_devices: ['cpu', 'mps', 'cuda:N'...] """ - if render_devices in ("cpu", "auto"): + if render_devices in ("cpu", "auto", "mps"): render_devices = [render_devices] elif render_devices is not None: if isinstance(render_devices, str): render_devices = [render_devices] if isinstance(render_devices, list) and len(render_devices) > 0: - render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices)) + render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices)) if len(render_devices) == 0: raise Exception( 'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}' @@ -160,14 +160,18 @@ def needs_to_force_full_precision(context): def get_max_vram_usage_level(device): - if device != "cpu": + if "cuda" in device: _, mem_total = torch.cuda.mem_get_info(device) - mem_total /= float(10**9) + elif device == "mps": + mem_total = torch.mps.driver_allocated_memory() + else: + return "high" - if mem_total < 4.5: - return "low" - elif mem_total < 6.5: - return "balanced" + mem_total /= float(10**9) + if mem_total < 4.5: + return "low" + elif mem_total < 6.5: + return "balanced" return "high" @@ -200,7 +204,7 @@ def is_device_compatible(device): log.error(str(e)) return False - if device == "cpu": + if device == "cpu" or device == "mps": return True # Memory check try: diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index d2e112ac..42ce6f87 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -385,10 +385,16 @@ def get_devices(): } def get_device_info(device): - if "cuda" not in device: + if device == "cpu": return {"name": device_manager.get_processor_name()} - mem_free, mem_total = torch.cuda.mem_get_info(device) + if device == "mps": + mem_used = torch.mps.current_allocated_memory() + mem_total = torch.mps.driver_allocated_memory() + mem_free = mem_total - mem_used + else: + mem_free, mem_total = torch.cuda.mem_get_info(device) + mem_free /= float(10**9) mem_total /= float(10**9) @@ -400,14 +406,17 @@ def get_devices(): } # list the compatible devices - gpu_count = torch.cuda.device_count() - for device in range(gpu_count): + cuda_count = torch.cuda.device_count() + for device in range(cuda_count): device = f"cuda:{device}" if not device_manager.is_device_compatible(device): continue devices["all"].update({device: get_device_info(device)}) + if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + devices["all"].update({"mps": get_device_info("mps")}) + devices["all"].update({"cpu": get_device_info("cpu")}) # list the activated devices