From c10e7734014964147e4720b8a88c603cab909d88 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Sat, 19 Nov 2022 11:53:33 +0530
Subject: [PATCH] Speed up the model move, by using the earlier function to
 move modelCS and modelFS to the cpu

---
 ui/sd_internal/runtime.py | 65 ++++++++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 29 deletions(-)

diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index 486d2179..d0efff6a 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -219,29 +219,36 @@ def unload_models():
     gc()
 
 
-def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
-    if thread_data.device == target_device: return
-    start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-    if start_mem <= 0: return
-    model_name = model.__class__.__name__
-    print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
-    start_time = time.time()
-    model.to(target_device)
-    time_step = start_time
-    WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
-    last_mem = start_mem
-    is_transfering = True
-    while is_transfering:
-        time.sleep(0.5) # 500ms
-        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-        is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
-        last_mem = mem
-        if not is_transfering:
-            break;
-        if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
-            print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
-            time_step = time.time()
-    print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
+# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
+#     if thread_data.device == target_device: return
+#     start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
+#     if start_mem <= 0: return
+#     model_name = model.__class__.__name__
+#     print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
+#     start_time = time.time()
+#     model.to(target_device)
+#     time_step = start_time
+#     WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
+#     last_mem = start_mem
+#     is_transfering = True
+#     while is_transfering:
+#         time.sleep(0.5) # 500ms
+#         mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
+#         is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
+#         last_mem = mem
+#         if not is_transfering:
+#             break;
+#         if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
+#             print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
+#             time_step = time.time()
+#     print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
+
+def move_to_cpu(model):
+    if thread_data.device != "cpu":
+        mem = torch.cuda.memory_allocated() / 1e6
+        model.to("cpu")
+        while torch.cuda.memory_allocated() / 1e6 >= mem:
+            time.sleep(1)
 
 def load_model_gfpgan():
     if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
@@ -484,7 +491,8 @@ def do_mk_img(req: Request):
             mask = mask.half()
 
     # Send to CPU and wait until complete.
-    wait_model_move_to(thread_data.modelFS, 'cpu')
+    # wait_model_move_to(thread_data.modelFS, 'cpu')
+    move_to_cpu(thread_data.modelFS)
 
     assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
     t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -560,10 +568,6 @@ def do_mk_img(req: Request):
                 img_data[i] = x_sample
             del x_samples, x_samples_ddim, x_sample
 
-            if thread_data.reduced_memory:
-                # Send to CPU and wait until complete.
-                wait_model_move_to(thread_data.modelFS, 'cpu')
-
             print("saving images")
             for i in range(batch_size):
                 img = Image.fromarray(img_data[i])
@@ -617,6 +621,7 @@ def do_mk_img(req: Request):
                 # if thread_data.reduced_memory:
                 #     unload_filters()
 
+            move_to_cpu(thread_data.modelFS)
            del img_data
             gc()
             if thread_data.device != 'cpu':
@@ -656,7 +661,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
 
     # Send to CPU and wait until complete.
-    wait_model_move_to(thread_data.modelCS, 'cpu')
+    # wait_model_move_to(thread_data.modelCS, 'cpu')
+
+    move_to_cpu(thread_data.modelCS)
 
     if sampler_name == 'ddim':
         thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
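
A note on the new move_to_cpu(): it waits by polling torch.cuda.memory_allocated(),
which only drops once the CUDA tensors backing the model have actually been freed.
Two edge cases are worth knowing. First, memory_allocated() with no argument reads
the current CUDA device, whereas the old wait_model_move_to() passed
thread_data.device explicitly. Second, if the reading never falls below its
pre-move value (for example, because the model held no CUDA memory to begin with),
the while loop can spin indefinitely. Below is a minimal sketch of the same polling
idea with a timeout guard; the function name, the timeout, and the early exits are
illustrative assumptions, not part of this commit:

    import time
    import torch

    def move_to_cpu_with_timeout(model, poll_interval=1.0, timeout=30.0):
        # Hypothetical variant of this patch's move_to_cpu(), for illustration only.
        if not torch.cuda.is_available():
            return  # nothing to wait for on a CPU-only machine

        start_mem = torch.cuda.memory_allocated() / 1e6  # MB on the current device
        model.to("cpu")

        # Poll until allocated memory falls below the pre-move level, i.e. until
        # the model's CUDA tensors have been released. Give up after the timeout
        # so a no-op move (model already on the CPU) cannot block forever.
        deadline = time.time() + timeout
        while torch.cuda.memory_allocated() / 1e6 >= start_mem:
            if start_mem == 0 or time.time() > deadline:
                break
            time.sleep(poll_interval)

For comparison, torch.cuda.synchronize() blocks until all pending GPU work on the
device has finished; polling memory_allocated(), as this patch does, additionally
waits for the caching allocator to reclaim the model's memory, which is what makes
room for the next model within the same process.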