Merge pull request #987 from michaelgallacher/beta

Hotfix rollup
This commit is contained in:
cmdr2 2023-03-10 10:16:21 +05:30 committed by GitHub
commit 8907dabd4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 15 deletions

View File

@ -35,7 +35,7 @@ def get_device_delta(render_devices, active_devices):
render_devices = list(filter(lambda x: x.startswith("cuda:") or x == "mps", render_devices))
if len(render_devices) == 0:
raise Exception(
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "mps"} or {"render_devices": "auto"}'
)
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
@ -64,13 +64,24 @@ def get_device_delta(render_devices, active_devices):
return devices_to_start, devices_to_stop
def is_mps_available():
return platform.system() == "Darwin" and \
hasattr(torch.backends, 'mps') and \
torch.backends.mps.is_available() and \
torch.backends.mps.is_built()
def is_cuda_available():
return torch.cuda.is_available()
def auto_pick_devices(currently_active_devices):
global mem_free_threshold
if platform.system() == "Darwin" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
if is_mps_available():
return ["mps"]
if not torch.cuda.is_available():
if not is_cuda_available():
return ["cpu"]
device_count = torch.cuda.device_count()
@ -162,8 +173,6 @@ def needs_to_force_full_precision(context):
def get_max_vram_usage_level(device):
if "cuda" in device:
_, mem_total = torch.cuda.mem_get_info(device)
elif device == "mps":
mem_total = torch.mps.driver_allocated_memory()
else:
return "high"
@ -204,7 +213,7 @@ def is_device_compatible(device):
log.error(str(e))
return False
if device == "cpu" or device == "mps":
if not is_cuda_available():
return True
# Memory check
try:

View File

@ -385,16 +385,10 @@ def get_devices():
}
def get_device_info(device):
if device == "cpu":
if not device_manager.is_cuda_available():
return {"name": device_manager.get_processor_name()}
if device == "mps":
mem_used = torch.mps.current_allocated_memory()
mem_total = torch.mps.driver_allocated_memory()
mem_free = mem_total - mem_used
else:
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
@ -414,7 +408,7 @@ def get_devices():
devices["all"].update({device: get_device_info(device)})
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
if device_manager.is_mps_available():
devices["all"].update({"mps": get_device_info("mps")})
devices["all"].update({"cpu": get_device_info("cpu")})