diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 3deca3be..ed6e7c80 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -186,7 +186,14 @@ def load_model_real_esrgan(real_esrgan_to_use): print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision) def mk_img(req: Request): - global modelFS, device + try: + yield from do_mk_img(req) + except Exception as e: + gc() + raise e + +def do_mk_img(req: Request): + global model, modelCS, modelFS, device global model_gfpgan, model_real_esrgan global stop_processing @@ -204,6 +211,7 @@ def mk_img(req: Request): device = 'cpu' if model_is_half: + del model, modelCS, modelFS load_model_ckpt(ckpt_file, device) load_model_gfpgan(gfpgan_file) @@ -218,7 +226,8 @@ def mk_img(req: Request): (req.init_image is None and model_fs_is_half) or \ (req.init_image is not None and not model_fs_is_half and not force_full_precision): - load_model_ckpt(ckpt_file, device, model.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision)) + del model, modelCS, modelFS + load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision)) if prev_device != device: load_model_gfpgan(gfpgan_file) @@ -363,6 +372,9 @@ def mk_img(req: Request): img.save(buf, format='JPEG') buf.seek(0) + del img, x_sample, x_samples_ddim + # don't delete x_samples, it is used in the code that called this callback + temp_images[str(req.session_id) + '/' + str(i)] = buf partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) @@ -430,6 +442,8 @@ def mk_img(req: Request): if opt_save_to_disk_path is not None: res_image_orig.path_abs = img_out_path + del img + if has_filters and not stop_processing: print('Applying filters..') @@ -459,15 +473,18 @@ def mk_img(req: Request): save_image(filtered_image, filtered_img_out_path) res_image_filtered.path_abs = filtered_img_out_path + del filtered_image + seeds += str(opt_seed) + "," opt_seed += 1 + gc() if device != "cpu": mem = torch.cuda.memory_allocated() / 1e6 modelFS.to("cpu") while torch.cuda.memory_allocated() / 1e6 >= mem: time.sleep(1) - del x_samples + del x_samples, x_samples_ddim, x_sample print("memory_final = ", torch.cuda.memory_allocated() / 1e6) print('Task completed')