Free up VRAM when possible

cmdr2 2022-09-21 21:53:25 +05:30
parent 6f60e71ea4
commit 7d12dbd4b2
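
The hunks below repeatedly call a gc() helper that is defined elsewhere in runtime.py and not shown in this diff. As a rough sketch of what such a helper plausibly does (the body here is an assumption; only gc.collect(), torch.cuda.empty_cache(), and torch.cuda.ipc_collect() are known stdlib/PyTorch calls):

import gc as py_gc   # aliased so the sketch can reuse the gc() name seen in the diff
import torch

def gc():
    # Collect unreachable Python objects first so tensor refcounts drop,
    # then hand PyTorch's cached CUDA blocks back to the driver.
    py_gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()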

@@ -186,7 +186,14 @@ def load_model_real_esrgan(real_esrgan_to_use):
     print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
 
+def mk_img(req: Request):
+    global modelFS, device
+
+    try:
+        yield from do_mk_img(req)
+    except Exception as e:
+        gc()
+        raise e
+
-def mk_img(req: Request):
+def do_mk_img(req: Request):
     global model, modelCS, modelFS, device
     global model_gfpgan, model_real_esrgan
     global stop_processing
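
This split works because mk_img stays a generator: an exception raised while the consumer iterates do_mk_img surfaces inside mk_img's try block, so gc() frees VRAM before the error reaches the caller. A hypothetical caller to illustrate (handle is a stand-in, not project code):

def consume(req, handle=print):       # handle: hypothetical consumer of partial results
    try:
        for partial in mk_img(req):   # an exception raised inside do_mk_img
            handle(partial)           # surfaces here, within mk_img's try block
    except Exception:
        pass                          # by this point gc() has already run
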
@@ -204,6 +211,7 @@ def mk_img(req: Request):
         device = 'cpu'
 
         if model_is_half:
+            del model, modelCS, modelFS
             load_model_ckpt(ckpt_file, device)
 
             load_model_gfpgan(gfpgan_file)
@ -218,7 +226,8 @@ def mk_img(req: Request):
(req.init_image is None and model_fs_is_half) or \
(req.init_image is not None and not model_fs_is_half and not force_full_precision):
load_model_ckpt(ckpt_file, device, model.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))
del model, modelCS, modelFS
load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))
if prev_device != device:
load_model_gfpgan(gfpgan_file)
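
This hunk and the previous one follow the same rule: del the old model references before load_model_ckpt allocates their replacements, so peak VRAM holds one copy of the weights rather than two. A generic sketch of the pattern, not project code (reload_weights and net are hypothetical names):

import torch

net = None   # hypothetical module-level model reference

def reload_weights(path, device):
    global net
    net = None                                    # drop the old copy first...
    torch.cuda.empty_cache()                      # ...and let the allocator reclaim it
    net = torch.load(path, map_location=device)   # only now allocate the new one
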
@@ -363,6 +372,9 @@ def mk_img(req: Request):
                             img.save(buf, format='JPEG')
                             buf.seek(0)
 
+                            del img, x_sample, x_samples_ddim
+                            # don't delete x_samples, it is used in the code that called this callback
+
                             temp_images[str(req.session_id) + '/' + str(i)] = buf
                             partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
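
The new comment states the ownership rule at work: the callback frees only the per-image objects it created, while x_samples still belongs to the code that invoked it. The same rule in a minimal standalone form (to_pil and emit are hypothetical hooks):

def on_progress(x_samples, i, to_pil, emit):
    for x_sample in x_samples:
        img = to_pil(x_sample)    # per-image copy created in this scope
        emit(img, i)
        del img, x_sample         # free what this callback owns...
    # ...but never `del x_samples`: the calling code still uses it
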
@@ -430,6 +442,8 @@ def mk_img(req: Request):
                     if opt_save_to_disk_path is not None:
                         res_image_orig.path_abs = img_out_path
 
+                    del img
+
                     if has_filters and not stop_processing:
                         print('Applying filters..')
@@ -459,15 +473,18 @@ def mk_img(req: Request):
                         save_image(filtered_image, filtered_img_out_path)
                         res_image_filtered.path_abs = filtered_img_out_path
 
+                        del filtered_image
+
             seeds += str(opt_seed) + ","
             opt_seed += 1
 
+        gc()
 
         if device != "cpu":
             mem = torch.cuda.memory_allocated() / 1e6
             modelFS.to("cpu")
             while torch.cuda.memory_allocated() / 1e6 >= mem:
                 time.sleep(1)
 
-    del x_samples
+    del x_samples, x_samples_ddim, x_sample
     print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
 
     print('Task completed')
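
The modelFS block above is worth noting as an idiom: after modelFS.to("cpu") the code polls torch.cuda.memory_allocated() until usage drops below its snapshot, treating the reclaim as something to wait for rather than assume. The same idiom as a reusable helper, a sketch rather than project code:

import time
import torch

def offload_and_wait(module):
    # Snapshot VRAM usage, move the module to system RAM, then block
    # until the allocator confirms the memory was actually reclaimed.
    before = torch.cuda.memory_allocated() / 1e6   # MB
    module.to("cpu")
    while torch.cuda.memory_allocated() / 1e6 >= before:
        time.sleep(1)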