Memory improvements

Marc-Andre Ferland 2022-10-21 03:53:43 -04:00
parent 1442748f58
commit ccb7a553c2


@@ -9,6 +9,7 @@ import os, re
import traceback
import torch
import numpy as np
from gc import collect as gc_collect
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
@@ -104,6 +105,7 @@ def device_init(device_selection=None):
thread_data.turbo = False
thread_data.has_valid_gpu = False
thread_data.force_full_precision = False
thread_data.reduced_memory = True
if device_selection.lower() == 'cpu':
print('CPU requested, skipping gpu init.')
@@ -172,7 +174,6 @@ def load_model_ckpt():
if not thread_data.unet_bs:
thread_data.unet_bs = 1
unload_model()
if thread_data.device == 'cpu':
thread_data.precision = 'full'
@@ -213,14 +214,20 @@ def load_model_ckpt():
modelCS.eval()
modelCS.cond_stage_model.device = torch.device(thread_data.device)
if thread_data.device != 'cpu':
modelCS.to(thread_data.device)
if thread_data.reduced_memory:
modelCS.to('cpu')
else:
modelCS.to(thread_data.device) # Preload on device if not already there.
thread_data.modelCS = modelCS
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
if thread_data.device != 'cpu':
modelFS.to(thread_data.device)
if thread_data.reduced_memory:
modelFS.to('cpu')
else:
modelFS.to(thread_data.device) # Preload on device if not already there.
thread_data.modelFS = modelFS
del sd
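The hunk above implements the core of the new reduced-memory mode: when thread_data.reduced_memory is set, the conditioning model (modelCS) and the first-stage model (modelFS) are parked in system RAM right after loading instead of being preloaded on the GPU. A minimal sketch of that placement policy, assuming thread_data and a loaded sub-model as in this file:

def place_after_load(model, thread_data):
    # Sketch only: park the sub-model on the CPU in reduced-memory mode,
    # otherwise preload it once on the rendering device.
    if thread_data.device != 'cpu' and not thread_data.reduced_memory:
        model.to(thread_data.device)
    else:
        model.to('cpu')
    return model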
@@ -236,20 +243,55 @@ def load_model_ckpt():
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
def unload_model():
def unload_filters():
if thread_data.model_gfpgan is not None:
del thread_data.model_gfpgan
thread_data.model_gfpgan = None
if thread_data.model_real_esrgan is not None:
del thread_data.model_real_esrgan
thread_data.model_real_esrgan = None
def unload_models():
if thread_data.model is not None:
print('Unloading models...')
del thread_data.model
del thread_data.modelCS
del thread_data.modelFS
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
def wait_move(model, target_device=None): # Send to target_device and wait until complete.
if thread_data.device == "cpu" or thread_data.device == target_device: return
if target_device is None: target_device = 'cpu'
start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
if start_mem <= 0: return
model_name = model.__class__.__name__
print(f'Device:{thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}MB')
model.to(target_device)
start_time = time.time()
time_step = start_time
WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
last_mem = start_mem
is_transferring = True
while is_transferring:
time.sleep(0.5) # 500ms
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
is_transferring = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
last_mem = mem
if not is_transferring:
break
if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
print(f'Device:{thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}MB, Transferred: {round(start_mem - mem)}MB')
time_step = time.time()
print(f'Device:{thread_data.device} - {model_name} Moved: {round(start_mem - mem)}MB in {round(time.time() - start_time, 3)} seconds to {target_device}')
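wait_move kicks off model.to(target_device) and then polls torch.cuda.memory_allocated until the figure stops shrinking, because VRAM is not necessarily released by the time .to() returns. A standalone toy of the same polling idea (the helper name, threshold, and timeout are illustrative, not from this commit):

import time
import torch

def wait_vram_below(device, threshold_mb, timeout=30.0):
    # Illustrative only: block until allocated VRAM on `device` drops
    # below threshold_mb, or give up after `timeout` seconds.
    deadline = time.time() + timeout
    while torch.cuda.memory_allocated(device) / 1e6 > threshold_mb:
        if time.time() > deadline:
            raise TimeoutError(f'VRAM still above {threshold_mb}MB on {device}')
        time.sleep(0.5)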
def load_model_gfpgan():
if thread_data.gfpgan_file is None:
print('load_model_gfpgan called without setting gfpgan_file')
return
if thread_data.gfpgan_file is None: raise ValueError('Thread gfpgan_file is undefined.')
#print('load_model_gfpgan called without setting gfpgan_file')
#return
if not is_first_cuda_device(thread_data.device):
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}. Cannot run GFPGANer.')
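is_first_cuda_device is defined elsewhere in this module; a plausible reading of the check, written out here purely as an assumption for illustration:

import torch

def is_first_cuda_device(device):
    # Assumed behavior, not the actual definition: treat 'cuda' (no index)
    # and 'cuda:0' as the first device.
    d = torch.device(device)
    return d.type == 'cuda' and (d.index is None or d.index == 0)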
@@ -258,9 +300,9 @@ def load_model_gfpgan():
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
def load_model_real_esrgan():
if thread_data.real_esrgan_file is None:
print('load_model_real_esrgan called without setting real_esrgan_file')
return
if thread_data.real_esrgan_file is None: raise ValueError('Thread real_esrgan_file is undefined.')
#print('load_model_real_esrgan called without setting real_esrgan_file')
#return
model_path = thread_data.real_esrgan_file + ".pth"
RealESRGAN_models = {
@@ -294,7 +336,7 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
def apply_filters(filter_name, image_data):
def apply_filters(filter_name, image_data, model_path=None):
print(f'Applying filter {filter_name}...')
if isinstance(image_data, torch.Tensor):
print(image_data)
@@ -303,12 +345,22 @@ def apply_filters(filter_name, image_data):
gc()
if filter_name == 'gfpgan':
if model_path is not None and model_path != thread_data.gfpgan_file:
thread_data.gfpgan_file = model_path
load_model_gfpgan()
elif not thread_data.model_gfpgan:
load_model_gfpgan()
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan':
if model_path is not None and model_path != thread_data.real_esrgan_file:
thread_data.real_esrgan_file = model_path
load_model_real_esrgan()
elif not thread_data.model_real_esrgan:
load_model_real_esrgan()
if thread_data.model_real_esrgan is None: raise Exception('Model "real_esrgan" not loaded.')
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
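With the new model_path parameter, a caller can switch filter checkpoints per request; a path that differs from the currently loaded file triggers a reload before the filter runs. Hypothetical call sites (the checkpoint names are illustrative):

fixed = apply_filters('gfpgan', image_data, model_path='GFPGANv1.3')
upscaled = apply_filters('real_esrgan', fixed, model_path='RealESRGAN_x4plus')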
@@ -338,6 +390,53 @@ def mk_img(req: Request):
"detail": str(e)
})
def update_temp_img(req, x_samples):
partial_images = []
for i in range(req.num_outputs):
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
return partial_images
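update_temp_img stashes each partial render as a JPEG BytesIO in thread_data.temp_images, keyed by session_id/i, and returns the matching /image/tmp/... paths. The serving side of that route is not part of this diff; a hypothetical FastAPI sketch of how those buffers could be streamed back:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
temp_images = {}  # stand-in for thread_data.temp_images

@app.get('/image/tmp/{session_id}/{img_i}')
def get_tmp_image(session_id: str, img_i: int):
    buf = temp_images[f'{session_id}/{img_i}']  # BytesIO written by update_temp_img
    buf.seek(0)
    return StreamingResponse(buf, media_type='image/jpeg')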
# Build and return the appropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples
return empty_callback
thread_data.partial_x_samples = None
last_callback_time = -1
def img_callback(x_samples, i):
nonlocal last_callback_time
thread_data.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time}
if extra_props is not None:
progress.update(extra_props)
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples)
yield json.dumps(progress)
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return img_callback
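Note that img_callback contains a yield, so it is itself a generator function: each call returns a generator that the sampler must drain (e.g. with yield from) for the JSON progress strings to reach the HTTP stream. A self-contained toy of that contract:

import json

def toy_callback(x_samples, i):
    # Stands in for the callback built by get_image_progress_generator.
    yield json.dumps({'step': i, 'step_time': 0.0})

def toy_sampler(steps, img_callback):
    x = 0  # stand-in for the latent tensor
    for i in range(steps):
        x += 1
        yield from img_callback(x, i)  # forward progress to the caller
    yield x  # final samples

for out in toy_sampler(3, toy_callback):
    print(out)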
def do_mk_img(req: Request):
thread_data.stop_processing = False
@@ -353,7 +452,7 @@ def do_mk_img(req: Request):
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if thread_data.ckpt_file != req.use_stable_diffusion_model:
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
needs_model_reload = True
@@ -361,25 +460,19 @@ def do_mk_img(req: Request):
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
load_model_ckpt()
load_model_gfpgan()
load_model_real_esrgan()
needs_model_reload = False
needs_model_reload = True
if needs_model_reload:
unload_models()
unload_filters()
load_model_ckpt()
if req.use_face_correction is not None and req.use_face_correction != thread_data.gfpgan_file:
thread_data.gfpgan_file = req.use_face_correction
load_model_gfpgan()
if req.use_upscale is not None and req.use_upscale != thread_data.real_esrgan_file:
thread_data.real_esrgan_file = req.use_upscale
load_model_real_esrgan()
if thread_data.turbo != req.turbo:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
gc()
opt_prompt = req.prompt
opt_seed = req.seed
opt_n_iter = 1
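The block above funnels every reload trigger (no model loaded yet, a different checkpoint, a precision change) into a single needs_model_reload flag, tears down both the SD models and the filter models, and reloads once. Condensed as a predicate over the same fields (a simplified sketch; the real precision check also honors force_full_precision and model_is_half):

def should_reload(req, thread_data):
    if thread_data.model is None:
        return True  # nothing loaded yet
    if thread_data.ckpt_file != req.use_stable_diffusion_model:
        return True  # a different checkpoint was requested
    wanted = 'full' if req.use_full_precision else 'autocast'
    return thread_data.precision != wanted  # simplified precision check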
@@ -432,7 +525,7 @@ def do_mk_img(req: Request):
if thread_data.device != "cpu" and thread_data.precision == "autocast":
mask = mask.half()
move_fs_to_cpu()
wait_move(thread_data.modelFS) # Send to CPU and wait until complete.
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -450,7 +543,8 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
thread_data.modelCS.to(thread_data.device)
if thread_data.reduced_memory:
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
@@ -470,47 +564,11 @@ def do_mk_img(req: Request):
else:
c = thread_data.modelCS.get_learned_conditioning(prompts)
thread_data.modelFS.to(thread_data.device)
if thread_data.reduced_memory:
thread_data.modelFS.to(thread_data.device)
partial_x_samples = None
last_callback_time = -1
def img_callback(x_samples, i):
nonlocal partial_x_samples, last_callback_time
partial_x_samples = x_samples
if req.stream_progress_updates:
n_steps = req.num_inference_steps if req.init_image is None else t_enc
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "total_steps": n_steps, "step_time": step_time}
if req.stream_image_progress and i % 5 == 0:
partial_images = []
for i in range(batch_size):
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
del img, x_sample, x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
n_steps = req.num_inference_steps if req.init_image is None else t_enc
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
# run the handler
try:
@@ -520,16 +578,23 @@ def do_mk_img(req: Request):
else:
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
yield from x_samples
x_samples = partial_x_samples
if req.stream_progress_updates:
yield from x_samples
if hasattr(thread_data, 'partial_x_samples'):
if thread_data.partial_x_samples is not None:
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
except UserInitiatedStop:
if partial_x_samples is None:
if not hasattr(thread_data, 'partial_x_samples'):
continue
if thread_data.partial_x_samples is None:
del thread_data.partial_x_samples
continue
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
x_samples = partial_x_samples
print("saving images")
print("decoding images")
img_data = [None] * batch_size
for i in range(batch_size):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
@@ -538,7 +603,15 @@ def do_mk_img(req: Request):
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
img_data[i] = x_sample
del x_samples, x_samples_ddim, x_sample
if thread_data.reduced_memory:
wait_move(thread_data.modelFS) # Send to CPU and wait until complete.
print("saving images")
for i in range(batch_size):
img = Image.fromarray(img_data[i])
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
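The reordering above decodes every latent to a uint8 array first, deletes the GPU tensors, offloads modelFS, and only then does the slow CPU-side work (saving and filtering). A NumPy-only illustration of the decode-then-release shape handling (stand-in data, no GPU required):

import numpy as np

decoded = [np.random.rand(1, 3, 64, 64) for _ in range(2)]  # stand-ins for decode_first_stage outputs
img_data = [
    (255.0 * np.transpose(x[0], (1, 2, 0))).astype(np.uint8)  # c h w -> h w c, scaled to bytes
    for x in decoded
]
del decoded  # source arrays released before any slow disk I/O or filtering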
@@ -562,19 +635,18 @@ def do_mk_img(req: Request):
if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
del img
if has_filters and not thread_data.stop_processing:
filters_applied = []
if req.use_face_correction:
x_sample = apply_filters('gfpgan', x_sample)
img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
filters_applied.append(req.use_face_correction)
if req.use_upscale:
x_sample = apply_filters('real_esrgan', x_sample)
img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
filters_applied.append(req.use_upscale)
if len(filters_applied) > 0:
filtered_image = Image.fromarray(x_sample)
filtered_image = Image.fromarray(img_data[i])
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(response_image)
@@ -587,9 +659,10 @@ def do_mk_img(req: Request):
seeds += str(opt_seed) + ","
opt_seed += 1
move_fs_to_cpu()
if thread_data.reduced_memory:
unload_filters()
del img_data
gc()
del x_samples, x_samples_ddim, x_sample
if thread_data.device != 'cpu':
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}MB')
@@ -626,15 +699,7 @@ Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'MB')
thread_data.modelCS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'MB')
time.sleep(1)
print('Transferred', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'MB')
wait_move(thread_data.modelCS) # Send to CPU and wait until complete.
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
@@ -677,21 +742,10 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T=x_T,
sampler = 'ddim'
)
yield from samples_ddim
def move_fs_to_cpu():
if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'MB')
thread_data.modelFS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'MB')
time.sleep(1)
print('Transferred', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'MB')
def gc():
#gc.collect()
gc_collect()
if thread_data.device == 'cpu':
return
torch.cuda.empty_cache()
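gc() collects through the gc_collect alias because the function's own name shadows the gc module (the commented-out gc.collect() would resolve to this function object and fail), then returns the CUDA caching allocator's unused blocks to the driver. A small demo of what empty_cache actually changes (requires a CUDA device; it is memory_reserved, not memory_allocated, that shrinks):

import torch
from gc import collect as gc_collect

t = torch.zeros(1024, 1024, device='cuda')  # roughly 4 MB of VRAM
del t
gc_collect()  # drop any lingering references
print('reserved before:', torch.cuda.memory_reserved() / 1e6, 'MB')
torch.cuda.empty_cache()  # hand cached blocks back to the driver
print('reserved after:', torch.cuda.memory_reserved() / 1e6, 'MB')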