diff --git a/ui/easydiffusion/device_manager.py b/ui/easydiffusion/device_manager.py index d8dc4d86..d47807f3 100644 --- a/ui/easydiffusion/device_manager.py +++ b/ui/easydiffusion/device_manager.py @@ -35,7 +35,7 @@ def get_device_delta(render_devices, active_devices): render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices)) if len(render_devices) == 0: raise Exception( - 'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}' + 'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}' ) render_devices = list(filter(lambda x: is_device_compatible(x), render_devices)) @@ -64,13 +64,24 @@ def get_device_delta(render_devices, active_devices): return devices_to_start, devices_to_stop +def is_mps_available(): + return platform.system() == "Darwin" and \ + hasattr(torch.backends, 'mps') and \ + torch.backends.mps.is_available() and \ + torch.backends.mps.is_built() + + +def is_cuda_available(): + return torch.cuda.is_available() + + def auto_pick_devices(currently_active_devices): global mem_free_threshold - if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built(): + if is_mps_available(): return ["mps"] - if not torch.cuda.is_available(): + if not is_cuda_available(): return ["cpu"] device_count = torch.cuda.device_count() @@ -162,8 +173,6 @@ def needs_to_force_full_precision(context): def get_max_vram_usage_level(device): if "cuda" in device: _, mem_total = torch.cuda.mem_get_info(device) - elif device == "mps": - mem_total = torch.mps.driver_allocated_memory() else: return "high" @@ -204,7 +213,7 @@ def is_device_compatible(device): log.error(str(e)) return False - if device == "cpu" or device == "mps": + if not is_cuda_available(): return True # Memory check try: diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 42ce6f87..80c60323 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -385,16 +385,10 @@ def get_devices(): } def get_device_info(device): - if device == "cpu": + if not device_manager.is_cuda_available(): return {"name": device_manager.get_processor_name()} - if device == "mps": - mem_used = torch.mps.current_allocated_memory() - mem_total = torch.mps.driver_allocated_memory() - mem_free = mem_total - mem_used - else: - mem_free, mem_total = torch.cuda.mem_get_info(device) - + mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free /= float(10**9) mem_total /= float(10**9) @@ -414,7 +408,7 @@ def get_devices(): devices["all"].update({device: get_device_info(device)}) - if torch.backends.mps.is_available() and torch.backends.mps.is_built(): + if device_manager.is_mps_available(): devices["all"].update({"mps": get_device_info("mps")}) devices["all"].update({"cpu": get_device_info("cpu")})