diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index b88af09f..132658e1 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -9,6 +9,7 @@ import os, re
 import traceback
 import torch
 import numpy as np
+from gc import collect as gc_collect
 from omegaconf import OmegaConf
 from PIL import Image, ImageOps
 from tqdm import tqdm, trange
@@ -104,6 +105,7 @@ def device_init(device_selection=None):
     thread_data.turbo = False
     thread_data.has_valid_gpu = False
     thread_data.force_full_precision = False
+    thread_data.reduced_memory = True
 
     if device_selection.lower() == 'cpu':
         print('CPU requested, skipping gpu init.')
@@ -172,7 +174,6 @@ def load_model_ckpt():
     if not thread_data.unet_bs:
         thread_data.unet_bs = 1
 
-    unload_model()
 
     if thread_data.device == 'cpu':
         thread_data.precision = 'full'
@@ -213,14 +214,20 @@ def load_model_ckpt():
     modelCS.eval()
     modelCS.cond_stage_model.device = torch.device(thread_data.device)
     if thread_data.device != 'cpu':
-        modelCS.to(thread_data.device)
+        if thread_data.reduced_memory:
+            modelCS.to('cpu')
+        else:
+            modelCS.to(thread_data.device) # Preload on device if not already there.
     thread_data.modelCS = modelCS
 
     modelFS = instantiate_from_config(config.modelFirstStage)
     _, _ = modelFS.load_state_dict(sd, strict=False)
     modelFS.eval()
     if thread_data.device != 'cpu':
-        modelFS.to(thread_data.device)
+        if thread_data.reduced_memory:
+            modelFS.to('cpu')
+        else:
+            modelFS.to(thread_data.device) # Preload on device if not already there.
     thread_data.modelFS = modelFS
     del sd
 
@@ -236,20 +243,55 @@ def load_model_ckpt():
 
     print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
 
-def unload_model():
+def unload_filters():
+    if thread_data.model_gfpgan is not None:
+        del thread_data.model_gfpgan
+    thread_data.model_gfpgan = None
+
+    if thread_data.model_real_esrgan is not None:
+        del thread_data.model_real_esrgan
+    thread_data.model_real_esrgan = None
+
+def unload_models():
     if thread_data.model is not None:
         print('Unloading models...')
         del thread_data.model
         del thread_data.modelCS
         del thread_data.modelFS
+
     thread_data.model = None
     thread_data.modelCS = None
     thread_data.modelFS = None
 
+def wait_move(model, target_device=None): # Send to target_device and wait until complete.
+    if thread_data.device == "cpu" or thread_data.device == target_device: return
+    if target_device is None: target_device = 'cpu'
+    start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
+    if start_mem <= 0: return
+    model_name = model.__class__.__name__
+    print(f'Device:{thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mo')
+    model.to(target_device)
+    start_time = time.time()
+    time_step = start_time
+    WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
+    last_mem = start_mem
+    is_transfering = True
+    while is_transfering:
+        time.sleep(0.5) # 500ms
+        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
+        is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
+        last_mem = mem
+        if not is_transfering:
+            break
+        if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
+            print(f'Device:{thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mo, Transferred: {round(start_mem - mem)}Mo')
+            time_step = time.time()
+    print(f'Device:{thread_data.device} - {model_name} Moved: {round(start_mem - mem)}Mo in {round(time.time() - start_time, 3)} seconds to {target_device}')
+
 def load_model_gfpgan():
-    if thread_data.gfpgan_file is None:
-        print('load_model_gfpgan called without setting gfpgan_file')
-        return
+    if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
+    #print('load_model_gfpgan called without setting gfpgan_file')
+    #return
     if not is_first_cuda_device(thread_data.device):
         #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
         raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}. Cannot run GFPGANer.')
@@ -258,9 +300,9 @@ def load_model_gfpgan():
     print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
 
 def load_model_real_esrgan():
-    if thread_data.real_esrgan_file is None:
-        print('load_model_real_esrgan called without setting real_esrgan_file')
-        return
+    if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
+    #print('load_model_real_esrgan called without setting real_esrgan_file')
+    #return
     model_path = thread_data.real_esrgan_file + ".pth"
 
     RealESRGAN_models = {
@@ -294,7 +336,7 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
         return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
     return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
 
-def apply_filters(filter_name, image_data):
+def apply_filters(filter_name, image_data, model_path=None):
     print(f'Applying filter {filter_name}...')
     if isinstance(image_data, torch.Tensor):
         print(image_data)
@@ -303,12 +345,22 @@ def apply_filters(filter_name, image_data):
         gc()
 
     if filter_name == 'gfpgan':
+        if model_path is not None and model_path != thread_data.gfpgan_file:
+            thread_data.gfpgan_file = model_path
+            load_model_gfpgan()
+        elif not thread_data.model_gfpgan:
+            load_model_gfpgan()
         if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
         print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
         _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
        image_data = output[:,:,::-1]
 
     if filter_name == 'real_esrgan':
+        if model_path is not None and model_path != thread_data.real_esrgan_file:
+            thread_data.real_esrgan_file = model_path
+            load_model_real_esrgan()
+        elif not thread_data.model_real_esrgan:
+            load_model_real_esrgan()
         if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
         print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
         output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
@@ -338,6 +390,53 @@ def mk_img(req: Request):
             "detail": str(e)
         })
 
+def update_temp_img(req, x_samples):
+    partial_images = []
+    for i in range(req.num_outputs):
+        x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+        x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
+        x_sample = x_sample.astype(np.uint8)
+        img = Image.fromarray(x_sample)
+        buf = BytesIO()
+        img.save(buf, format='JPEG')
+        buf.seek(0)
+
+        del img, x_sample, x_sample_ddim
+        # don't delete x_samples, it is used in the code that called this callback
+
+        thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
+        partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
+    return partial_images
+
+# Build and return the appropriate generator for do_mk_img
+def get_image_progess_generator(req, extra_props=None):
+    if not req.stream_progress_updates:
+        def empty_callback(x_samples, i): return x_samples
+        return empty_callback
+
+    thread_data.partial_x_samples = None
+    last_callback_time = -1
+    def img_callback(x_samples, i):
+        nonlocal last_callback_time
+
+        thread_data.partial_x_samples = x_samples
+        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
+        last_callback_time = time.time()
+
+        progress = {"step": i, "step_time": step_time}
+        if extra_props is not None:
+            progress.update(extra_props)
+
+        if req.stream_image_progress and i % 5 == 0:
+            progress['output'] = update_temp_img(req, x_samples)
+
+        yield json.dumps(progress)
+
+        if thread_data.stop_processing:
+            raise UserInitiatedStop("User requested that we stop processing")
+    return img_callback
+
 def do_mk_img(req: Request):
     thread_data.stop_processing = False
 
@@ -353,7 +452,7 @@ def do_mk_img(req: Request):
     if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
 
     needs_model_reload = False
-    if thread_data.ckpt_file != req.use_stable_diffusion_model:
+    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
         thread_data.ckpt_file = req.use_stable_diffusion_model
         needs_model_reload = True
 
@@ -361,25 +460,19 @@ def do_mk_img(req: Request):
     if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
         (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
         thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-        load_model_ckpt()
-        load_model_gfpgan()
-        load_model_real_esrgan()
-        needs_model_reload = False
+        needs_model_reload = True
 
     if needs_model_reload:
+        unload_models()
+        unload_filters()
         load_model_ckpt()
 
-    if req.use_face_correction is not None and req.use_face_correction != thread_data.gfpgan_file:
-        thread_data.gfpgan_file = req.use_face_correction
-        load_model_gfpgan()
-    if req.use_upscale is not None and req.use_upscale != thread_data.real_esrgan_file:
-        thread_data.real_esrgan_file = req.use_upscale
-        load_model_real_esrgan()
-
     if thread_data.turbo != req.turbo:
         thread_data.turbo = req.turbo
         thread_data.model.turbo = req.turbo
 
+    gc()
+
     opt_prompt = req.prompt
     opt_seed = req.seed
     opt_n_iter = 1
@@ -432,7 +525,7 @@ def do_mk_img(req: Request):
             if thread_data.device != "cpu" and thread_data.precision == "autocast":
                 mask = mask.half()
 
-    move_fs_to_cpu()
+    wait_move(thread_data.modelFS) # Send to CPU and wait until complete.
 
     assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
     t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -450,7 +543,8 @@ def do_mk_img(req: Request):
         for prompts in tqdm(data, desc="data"):
 
             with precision_scope("cuda"):
-                thread_data.modelCS.to(thread_data.device)
+                if thread_data.reduced_memory:
+                    thread_data.modelCS.to(thread_data.device)
                 uc = None
                 if req.guidance_scale != 1.0:
                     uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
@@ -470,47 +564,11 @@ def do_mk_img(req: Request):
                 else:
                     c = thread_data.modelCS.get_learned_conditioning(prompts)
 
-                thread_data.modelFS.to(thread_data.device)
+                if thread_data.reduced_memory:
+                    thread_data.modelFS.to(thread_data.device)
 
-                partial_x_samples = None
-                last_callback_time = -1
-                def img_callback(x_samples, i):
-                    nonlocal partial_x_samples, last_callback_time
-
-                    partial_x_samples = x_samples
-
-                    if req.stream_progress_updates:
-                        n_steps = req.num_inference_steps if req.init_image is None else t_enc
-                        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
-                        last_callback_time = time.time()
-
-                        progress = {"step": i, "total_steps": n_steps, "step_time": step_time}
-
-                        if req.stream_image_progress and i % 5 == 0:
-                            partial_images = []
-
-                            for i in range(batch_size):
-                                x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
-                                x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                                x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
-                                x_sample = x_sample.astype(np.uint8)
-                                img = Image.fromarray(x_sample)
-                                buf = BytesIO()
-                                img.save(buf, format='JPEG')
-                                buf.seek(0)
-
-                                del img, x_sample, x_samples_ddim
-                                # don't delete x_samples, it is used in the code that called this callback
-
-                                thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
-                                partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
-
-                            progress['output'] = partial_images
-
-                        yield json.dumps(progress)
-
-                    if thread_data.stop_processing:
-                        raise UserInitiatedStop("User requested that we stop processing")
+                n_steps = req.num_inference_steps if req.init_image is None else t_enc
+                img_callback = get_image_progess_generator(req, {"total_steps": n_steps})
 
                 # run the handler
                 try:
@@ -520,16 +578,23 @@ def do_mk_img(req: Request):
                     else:
                        x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
 
-                    yield from x_samples
-
-                    x_samples = partial_x_samples
+                    if req.stream_progress_updates:
+                        yield from x_samples
+                    if hasattr(thread_data, 'partial_x_samples'):
+                        if thread_data.partial_x_samples is not None:
+                            x_samples = thread_data.partial_x_samples
+                        del thread_data.partial_x_samples
                 except UserInitiatedStop:
-                    if partial_x_samples is None:
+                    if not hasattr(thread_data, 'partial_x_samples'):
                         continue
+                    if thread_data.partial_x_samples is None:
+                        del thread_data.partial_x_samples
+                        continue
+                    x_samples = thread_data.partial_x_samples
+                    del thread_data.partial_x_samples
 
-                    x_samples = partial_x_samples
-
-                print("saving images")
+                print("decoding images")
+                img_data = [None] * batch_size
                 for i in range(batch_size):
                     img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
                     img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
@@ -538,7 +603,15 @@ def do_mk_img(req: Request):
                     x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                     x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                     x_sample = x_sample.astype(np.uint8)
-                    img = Image.fromarray(x_sample)
+                    img_data[i] = x_sample
+                    del x_samples, x_samples_ddim, x_sample
+
+                if thread_data.reduced_memory:
+                    wait_move(thread_data.modelFS) # Send to CPU and wait until complete.
+
+                print("saving images")
+                for i in range(batch_size):
+                    img = Image.fromarray(img_data[i])
 
                     has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                                   (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
@@ -562,19 +635,18 @@ def do_mk_img(req: Request):
                     if req.save_to_disk_path is not None:
                         res_image_orig.path_abs = img_out_path
 
-                    del img
-
                     if has_filters and not thread_data.stop_processing:
                         filters_applied = []
                         if req.use_face_correction:
-                            x_sample = apply_filters('gfpgan', x_sample)
+                            img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                            filters_applied.append(req.use_face_correction)
                         if req.use_upscale:
-                            x_sample = apply_filters('real_esrgan', x_sample)
+                            img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                             filters_applied.append(req.use_upscale)
                         if (len(filters_applied) > 0):
-                            filtered_image = Image.fromarray(x_sample)
+                            filtered_image = Image.fromarray(img_data[i])
                             filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
                             response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                             res.images.append(response_image)
@@ -587,9 +659,10 @@ def do_mk_img(req: Request):
                     seeds += str(opt_seed) + ","
                     opt_seed += 1
 
-                move_fs_to_cpu()
+            if thread_data.reduced_memory:
+                unload_filters()
+            del img_data
             gc()
-            del x_samples, x_samples_ddim, x_sample
 
     if thread_data.device != 'cpu':
         print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
@@ -626,15 +699,7 @@ Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
 
 def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
 
-    if thread_data.device != "cpu":
-        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-        print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
-        thread_data.modelCS.to("cpu")
-        while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
-            print('Device:', thread_data.device, 'Waiting Memory transfer. Memory Used:', round(mem, 2), 'Mo')
-            time.sleep(1)
-        print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
-
+    wait_move(thread_data.modelCS) # Send to CPU and wait until complete.
     if sampler_name == 'ddim':
         thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
@@ -677,21 +742,10 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
         x_T=x_T,
         sampler = 'ddim'
     )
-
     yield from samples_ddim
 
-def move_fs_to_cpu():
-    if thread_data.device != "cpu":
-        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-        print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
-        thread_data.modelFS.to("cpu")
-        while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
-            print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'Mo')
-            time.sleep(1)
-        print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
-
 def gc():
-    #gc.collect()
+    gc_collect()
     if thread_data.device == 'cpu': return
     torch.cuda.empty_cache()
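
Note on the polling approach used by the new wait_move() helper in this patch: rather than assuming that model.to('cpu') has finished freeing GPU memory, the patch polls torch.cuda.memory_allocated() until the number stops shrinking, printing progress if the transfer takes longer than WARNING_TIMEOUT. Below is a minimal standalone sketch of that pattern, outside the patch; wait_until_moved and its parameters are illustrative names, not part of the repository.

    import time
    import torch

    def wait_until_moved(model, device='cuda:0', poll_seconds=0.5):
        # Illustrative sketch: move a model off the GPU and poll allocated
        # CUDA memory until it stops shrinking, mirroring wait_move() above.
        if not torch.cuda.is_available():
            model.to('cpu')
            return
        start = torch.cuda.memory_allocated(device) / 1e6  # MB currently allocated
        model.to('cpu')
        last = start
        while True:
            time.sleep(poll_seconds)
            now = torch.cuda.memory_allocated(device) / 1e6
            if now <= 0 or now >= last:  # nothing left on the device, or no longer shrinking
                break
            last = now
        print(f'freed ~{round(start - now)} MB from {device}')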