More comments and cleanup.

Marc-Andre Ferland 2022-10-21 20:56:24 -04:00
parent 88ef1a3c5b
commit 7befa94e6d


@@ -207,6 +207,9 @@ def load_model_ckpt():
     model.turbo = thread_data.turbo
     if thread_data.device != 'cpu':
         model.to(thread_data.device)
+        #if thread_data.reduced_memory:
+        #    model.model1.to("cpu")
+        #    model.model2.to("cpu")
     thread_data.model = model

     modelCS = instantiate_from_config(config.modelCondStage)
@@ -263,9 +266,8 @@ def unload_models():
     thread_data.modelCS = None
     thread_data.modelFS = None

-def wait_move(model, target_device=None): # Send to target_device and wait until complete.
-    if thread_data.device == "cpu" or thread_data.device == target_device: return
-    if target_device is None: target_device = 'cpu'
+def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
+    if thread_data.device == target_device: return
     start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
     if start_mem <= 0: return
     model_name = model.__class__.__name__
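
The hunk ends before the body of the renamed helper. Below is a minimal sketch of the move-and-wait pattern the name describes, assuming the function polls torch.cuda.memory_allocated() until the transfer has actually released VRAM; the thread_data stub, poll interval, and timeout are illustrative, not the repository's code.

import time
import threading
import torch

thread_data = threading.local()  # stand-in for the module's per-thread state
thread_data.device = 'cuda:0'

def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
    if thread_data.device == target_device: return
    start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
    if start_mem <= 0: return
    model_name = model.__class__.__name__
    print(f'Sending model {model_name} to {target_device}, {round(start_mem)}Mb currently allocated.')
    model.to(target_device)
    # model.to() can return before CUDA has released the source allocation,
    # so poll until usage drops below the starting point (bounded at ~10s).
    for _ in range(100):
        if torch.cuda.memory_allocated(thread_data.device) / 1e6 < start_mem:
            break
        time.sleep(0.1)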
@@ -338,12 +340,11 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
 def apply_filters(filter_name, image_data, model_path=None):
     print(f'Applying filter {filter_name}...')
     gc() # Free space before loading new data.

     if isinstance(image_data, torch.Tensor):
-        print(image_data)
         image_data.to(thread_data.device)

     gc()
     if filter_name == 'gfpgan':
         if model_path is not None and model_path != thread_data.gfpgan_file:
             thread_data.gfpgan_file = model_path
@@ -373,18 +374,10 @@ def mk_img(req: Request):
         yield from do_mk_img(req)
     except Exception as e:
         print(traceback.format_exc())
-
-        gc()
-
-        if thread_data.device != "cpu":
-            thread_data.modelFS.to("cpu")
-            thread_data.modelCS.to("cpu")
-            thread_data.model.model1.to("cpu")
-            thread_data.model.model2.to("cpu")
-
-        gc()
-
+        # Model crashed, release all resources in unknown state.
+        unload_models()
+        unload_filters()
+        gc() # Release from memory.
         yield json.dumps({
             "status": 'failed',
             "detail": str(e)
@@ -471,6 +464,7 @@ def do_mk_img(req: Request):
     thread_data.turbo = req.turbo
     thread_data.model.turbo = req.turbo

+    # Start by cleaning memory, loading and unloading things can leave memory allocated.
     gc()

     opt_prompt = req.prompt
@@ -525,7 +519,8 @@ def do_mk_img(req: Request):
             if thread_data.device != "cpu" and thread_data.precision == "autocast":
                 mask = mask.half()

-        wait_move(thread_data.modelFS) # Send to CPU and wait until complete.
+        # Send to CPU and wait until complete.
+        wait_model_move_to(thread_data.modelFS, 'cpu')

         assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
         t_enc = int(req.prompt_strength * req.num_inference_steps)
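
The t_enc assignment above is the usual img2img strength mapping; a quick check of the arithmetic:

# t_enc = int(req.prompt_strength * req.num_inference_steps)
t_enc = int(0.8 * 25)  # == 20: the init image is re-noised for 20 of 25 steps
t_enc = int(0.0 * 25)  # ==  0: no noise added, output stays near the init image
t_enc = int(1.0 * 25)  # == 25: fully re-noised, close to pure txt2img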
@@ -607,7 +602,8 @@ def do_mk_img(req: Request):
                     del x_samples, x_samples_ddim, x_sample

                     if thread_data.reduced_memory:
-                        wait_move(thread_data.modelFS) # Send to CPU and wait until complete.
+                        # Send to CPU and wait until complete.
+                        wait_model_move_to(thread_data.modelFS, 'cpu')

                     print("saving images")
                     for i in range(batch_size):
@@ -699,7 +695,8 @@ Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
 def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

-    wait_move(thread_data.modelCS) # Send to CPU and wait until complete.
+    # Send to CPU and wait until complete.
+    wait_model_move_to(thread_data.modelCS, 'cpu')

     if sampler_name == 'ddim':
         thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
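
A self-contained illustration (assuming a CUDA device) of the effect these wait_model_move_to calls are after: parking a module on the CPU releases its share of torch.cuda.memory_allocated, freeing VRAM for the UNet before sampling starts.

import torch

net = torch.nn.Linear(4096, 4096).to('cuda')
print(torch.cuda.memory_allocated() / 1e6, 'MB with the layer resident')      # ~67 MB
net.to('cpu')
torch.cuda.synchronize()  # ensure the device-to-host copy has finished
print(torch.cuda.memory_allocated() / 1e6, 'MB after moving it off the GPU')  # ~0 MB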