Merge pull request #995 from cmdr2/beta

Beta
This commit is contained in:
cmdr2 2023-03-11 09:08:01 +05:30 committed by GitHub
commit bdb6649722
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 17 deletions

View File

@@ -22,20 +22,20 @@ mem_free_threshold = 0
def get_device_delta(render_devices, active_devices): def get_device_delta(render_devices, active_devices):
""" """
render_devices: 'cpu', or 'auto' or ['cuda:N'...] render_devices: 'cpu', or 'auto', or 'mps' or ['cuda:N'...]
active_devices: ['cpu', 'cuda:N'...] active_devices: ['cpu', 'mps', 'cuda:N'...]
""" """
if render_devices in ("cpu", "auto"): if render_devices in ("cpu", "auto", "mps"):
render_devices = [render_devices] render_devices = [render_devices]
elif render_devices is not None: elif render_devices is not None:
if isinstance(render_devices, str): if isinstance(render_devices, str):
render_devices = [render_devices] render_devices = [render_devices]
if isinstance(render_devices, list) and len(render_devices) > 0: if isinstance(render_devices, list) and len(render_devices) > 0:
render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices)) render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
if len(render_devices) == 0: if len(render_devices) == 0:
raise Exception( raise Exception(
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}' 'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
) )
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices)) render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
@@ -64,13 +64,26 @@ def get_device_delta(render_devices, active_devices):
return devices_to_start, devices_to_stop return devices_to_start, devices_to_stop
def is_mps_available():
    """Report whether torch can use Apple's Metal Performance Shaders backend.

    True only on macOS ("Darwin"), with a torch build that exposes
    `torch.backends.mps`, where the backend is both built and available.
    """
    if platform.system() != "Darwin":
        return False
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is None:
        return False
    return mps_backend.is_available() and mps_backend.is_built()
def is_cuda_available():
    """Report whether torch can see at least one usable CUDA device."""
    return bool(torch.cuda.is_available())
def auto_pick_devices(currently_active_devices): def auto_pick_devices(currently_active_devices):
global mem_free_threshold global mem_free_threshold
if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built(): if is_mps_available():
return ["mps"] return ["mps"]
if not torch.cuda.is_available(): if not is_cuda_available():
return ["cpu"] return ["cpu"]
device_count = torch.cuda.device_count() device_count = torch.cuda.device_count()
@@ -160,14 +173,16 @@ def needs_to_force_full_precision(context):
def get_max_vram_usage_level(device):
    """Classify a device's VRAM budget as "low", "balanced" or "high".

    device: a torch device string — "cpu", "mps", or "cuda:N".

    Only CUDA devices can report memory via `torch.cuda.mem_get_info`;
    every non-CUDA device is treated as "high" (no VRAM-based throttling).
    Thresholds: < 4.5 GB -> "low", < 6.5 GB -> "balanced", else "high".
    """
    if "cuda" in device:
        # mem_get_info returns (free_bytes, total_bytes); only total matters here.
        _, mem_total = torch.cuda.mem_get_info(device)
    else:
        return "high"

    mem_total /= float(10**9)  # bytes -> GB (decimal)

    if mem_total < 4.5:
        return "low"
    elif mem_total < 6.5:
        return "balanced"

    return "high"
@@ -200,7 +215,7 @@ def is_device_compatible(device):
log.error(str(e)) log.error(str(e))
return False return False
if device == "cpu": if device in ("cpu", "mps"):
return True return True
# Memory check # Memory check
try: try:

View File

@@ -385,7 +385,7 @@ def get_devices():
} }
def get_device_info(device): def get_device_info(device):
if "cuda" not in device: if device in ("cpu", "mps"):
return {"name": device_manager.get_processor_name()} return {"name": device_manager.get_processor_name()}
mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free, mem_total = torch.cuda.mem_get_info(device)
@@ -400,14 +400,17 @@ def get_devices():
} }
# list the compatible devices # list the compatible devices
gpu_count = torch.cuda.device_count() cuda_count = torch.cuda.device_count()
for device in range(gpu_count): for device in range(cuda_count):
device = f"cuda:{device}" device = f"cuda:{device}"
if not device_manager.is_device_compatible(device): if not device_manager.is_device_compatible(device):
continue continue
devices["all"].update({device: get_device_info(device)}) devices["all"].update({device: get_device_info(device)})
if device_manager.is_mps_available():
devices["all"].update({"mps": get_device_info("mps")})
devices["all"].update({"cpu": get_device_info("cpu")}) devices["all"].update({"cpu": get_device_info("cpu")})
# list the activated devices # list the activated devices