Mirror of https://github.com/easydiffusion/easydiffusion.git (synced 2025-04-10 11:08:20 +02:00)
Speed up the model move by using the earlier function to move modelCS and modelFS to the CPU
This commit is contained in:
parent f7af259576
commit c10e773401
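
The change in brief: the old wait_model_move_to() helper polled torch.cuda.memory_allocated() every 0.5 s and logged transfer progress; this commit comments it out and routes modelFS and modelCS through the existing move_to_cpu() helper instead. Below is a minimal, runnable sketch of that wait-until-freed polling pattern; the nn.Linear stand-in model and the explicit device argument are illustrative assumptions (the repo's helper reads thread_data.device from its worker-thread state).

import time

import torch
import torch.nn as nn

def move_to_cpu(model: nn.Module, device: str) -> None:
    # Same pattern as the helper this commit adopts: a no-op when the
    # session already runs on the CPU; otherwise start the move and poll
    # until allocated CUDA memory drops below its starting level, i.e.
    # until the transfer has actually released GPU memory.
    if device != "cpu":
        start_mb = torch.cuda.memory_allocated() / 1e6
        model.to("cpu")
        while torch.cuda.memory_allocated() / 1e6 >= start_mb:
            time.sleep(1)

if torch.cuda.is_available():
    model = nn.Linear(4096, 4096).to("cuda")  # hypothetical stand-in for modelCS/modelFS
    move_to_cpu(model, "cuda")
    assert next(model.parameters()).device.type == "cpu"

Relative to the commented-out helper, this keeps the same wait-until-freed idea but drops the progress logging, coarsens the poll from 0.5 s to 1 s, and skips the move entirely when the session is already on the CPU.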
@@ -219,29 +219,36 @@ def unload_models():
     gc()
 
-def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
-    if thread_data.device == target_device: return
-    start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-    if start_mem <= 0: return
-    model_name = model.__class__.__name__
-    print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
-    start_time = time.time()
-    model.to(target_device)
-    time_step = start_time
-    WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
-    last_mem = start_mem
-    is_transfering = True
-    while is_transfering:
-        time.sleep(0.5) # 500ms
-        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-        is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
-        last_mem = mem
-        if not is_transfering:
-            break;
-        if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
-            print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
-            time_step = time.time()
-    print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
+# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
+#     if thread_data.device == target_device: return
+#     start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
+#     if start_mem <= 0: return
+#     model_name = model.__class__.__name__
+#     print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
+#     start_time = time.time()
+#     model.to(target_device)
+#     time_step = start_time
+#     WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
+#     last_mem = start_mem
+#     is_transfering = True
+#     while is_transfering:
+#         time.sleep(0.5) # 500ms
+#         mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
+#         is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
+#         last_mem = mem
+#         if not is_transfering:
+#             break;
+#         if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
+#             print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
+#             time_step = time.time()
+#     print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
+
+def move_to_cpu(model):
+    if thread_data.device != "cpu":
+        mem = torch.cuda.memory_allocated() / 1e6
+        model.to("cpu")
+        while torch.cuda.memory_allocated() / 1e6 >= mem:
+            time.sleep(1)
 
 def load_model_gfpgan():
     if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
@@ -484,7 +491,8 @@ def do_mk_img(req: Request):
                 mask = mask.half()
 
         # Send to CPU and wait until complete.
-        wait_model_move_to(thread_data.modelFS, 'cpu')
+        # wait_model_move_to(thread_data.modelFS, 'cpu')
+        move_to_cpu(thread_data.modelFS)
 
         assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
         t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -560,10 +568,6 @@ def do_mk_img(req: Request):
                 img_data[i] = x_sample
             del x_samples, x_samples_ddim, x_sample
 
-            if thread_data.reduced_memory:
-                # Send to CPU and wait until complete.
-                wait_model_move_to(thread_data.modelFS, 'cpu')
-
             print("saving images")
             for i in range(batch_size):
                 img = Image.fromarray(img_data[i])
@@ -617,6 +621,7 @@ def do_mk_img(req: Request):
 
     # if thread_data.reduced_memory:
    #     unload_filters()
+    move_to_cpu(thread_data.modelFS)
     del img_data
     gc()
     if thread_data.device != 'cpu':
@@ -656,7 +661,9 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
 
     # Send to CPU and wait until complete.
-    wait_model_move_to(thread_data.modelCS, 'cpu')
+    # wait_model_move_to(thread_data.modelCS, 'cpu')
+
+    move_to_cpu(thread_data.modelCS)
 
     if sampler_name == 'ddim':
         thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)