diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index f2f8d169..26c116ad 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -28,6 +28,8 @@ from gfpgan import GFPGANer from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer +from threading import Lock + import uuid logging.set_verbosity_error() @@ -35,7 +37,7 @@ logging.set_verbosity_error() # consts config_yaml = "optimizedSD/v1-inference.yaml" filename_regex = re.compile('[^a-zA-Z0-9]') -force_gfpgan_to_cuda0 = True # workaround: gfpgan currently works only on cuda:0 +gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time. # api stuff from sd_internal import device_manager @@ -309,12 +311,6 @@ def move_to_cpu(model): def load_model_gfpgan(): if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.') - - # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files - from facexlib.detection import retinaface - retinaface.device = torch.device(thread_data.device) - print('forced retinaface.device to', thread_data.device) - model_path = thread_data.gfpgan_file + ".pth" thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision) @@ -370,15 +366,23 @@ def apply_filters(filter_name, image_data, model_path=None): image_data.to(thread_data.device) if filter_name == 'gfpgan': - if model_path is not None and model_path != thread_data.gfpgan_file: - thread_data.gfpgan_file = model_path - load_model_gfpgan() - elif not thread_data.model_gfpgan: - load_model_gfpgan() - if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') - print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision) - _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - image_data = output[:,:,::-1] + # This lock is only ever used here. No need to use timeout for the request. Should never deadlock. + with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting. + # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files + from facexlib.detection import retinaface + retinaface.device = torch.device(thread_data.device) + print('forced retinaface.device to', thread_data.device) + + if model_path is not None and model_path != thread_data.gfpgan_file: + thread_data.gfpgan_file = model_path + load_model_gfpgan() + elif not thread_data.model_gfpgan: + load_model_gfpgan() + if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') + + print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision) + _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) + image_data = output[:,:,::-1] if filter_name == 'real_esrgan': if model_path is not None and model_path != thread_data.real_esrgan_file: