Memory improvements

commit ccb7a553c2 (parent 1442748f58)
@@ -9,6 +9,7 @@ import os, re
import traceback
import torch
import numpy as np
from gc import collect as gc_collect
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
@@ -104,6 +105,7 @@ def device_init(device_selection=None):
    thread_data.turbo = False
    thread_data.has_valid_gpu = False
    thread_data.force_full_precision = False
    thread_data.reduced_memory = True

    if device_selection.lower() == 'cpu':
        print('CPU requested, skipping gpu init.')
@@ -172,7 +174,6 @@ def load_model_ckpt():
    if not thread_data.unet_bs:
        thread_data.unet_bs = 1

    unload_model()

    if thread_data.device == 'cpu':
        thread_data.precision = 'full'
@@ -213,14 +214,20 @@ def load_model_ckpt():
    modelCS.eval()
    modelCS.cond_stage_model.device = torch.device(thread_data.device)
    if thread_data.device != 'cpu':
        modelCS.to(thread_data.device)
        if thread_data.reduced_memory:
            modelCS.to('cpu')
        else:
            modelCS.to(thread_data.device) # Preload on device if not already there.
    thread_data.modelCS = modelCS

    modelFS = instantiate_from_config(config.modelFirstStage)
    _, _ = modelFS.load_state_dict(sd, strict=False)
    modelFS.eval()
    if thread_data.device != 'cpu':
        modelFS.to(thread_data.device)
        if thread_data.reduced_memory:
            modelFS.to('cpu')
        else:
            modelFS.to(thread_data.device) # Preload on device if not already there.
    thread_data.modelFS = modelFS
    del sd

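Note: the hunk above loads modelCS and modelFS and then immediately parks them on the CPU when thread_data.reduced_memory is set, so VRAM is only consumed while a stage is actually running. A minimal standalone sketch of that on-demand placement pattern (run_on_device, submodel and fn are hypothetical names, not part of this commit):

import torch

def run_on_device(submodel, fn, device, reduced_memory=True):
    """Move a submodel to the GPU just for one call, then park it back on the CPU."""
    if reduced_memory and device != 'cpu':
        submodel.to(device)            # load on demand
    try:
        return fn(submodel)            # do the actual work on the target device
    finally:
        if reduced_memory and device != 'cpu':
            submodel.to('cpu')         # free VRAM for the next stage
            torch.cuda.empty_cache()   # return cached blocks to the driver

The trade-off is extra host-to-device copies per step in exchange for a much smaller steady-state VRAM footprint.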
@@ -236,20 +243,55 @@ def load_model_ckpt():

    print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)

def unload_model():
def unload_filters():
    if thread_data.model_gfpgan is not None:
        del thread_data.model_gfpgan
        thread_data.model_gfpgan = None

    if thread_data.model_real_esrgan is not None:
        del thread_data.model_real_esrgan
        thread_data.model_real_esrgan = None

def unload_models():
    if thread_data.model is not None:
        print('Unloading models...')
        del thread_data.model
        del thread_data.modelCS
        del thread_data.modelFS

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None

def wait_move(model, target_device=None): # Send to target_device and wait until complete.
    if thread_data.device == "cpu" or thread_data.device == target_device: return
    if target_device is None: target_device = 'cpu'
    start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
    if start_mem <= 0: return
    model_name = model.__class__.__name__
    print(f'Device:{thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mo')
    model.to(target_device)
    start_time = time.time()
    time_step = start_time
    WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
    last_mem = start_mem
    is_transfering = True
    while is_transfering:
        time.sleep(0.5) # 500ms
        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
        is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
        last_mem = mem
        if not is_transfering:
            break;
        if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
            print(f'Device:{thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mo, Transfered: {round(start_mem - mem)}Mo')
            time_step = time.time()
    print(f'Device:{thread_data.device} - {model_name} Moved: {round(start_mem - mem)}Mo in {round(time.time() - start_time, 3)} seconds to {target_device}')

def load_model_gfpgan():
    if thread_data.gfpgan_file is None:
        print('load_model_gfpgan called without setting gfpgan_file')
        return
    if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
    #print('load_model_gfpgan called without setting gfpgan_file')
    #return
    if not is_first_cuda_device(thread_data.device):
        #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
        raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}. Cannot run GFPGANer.')
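Note: wait_move() above exploits the fact that model.to('cpu') can return before the GPU allocations are actually released, so it polls torch.cuda.memory_allocated() until the figure stops shrinking. A condensed sketch of the same polling idea with an explicit timeout (illustrative only; move_and_wait and its parameters are not from this commit):

import time
import torch

def move_and_wait(module, device, target='cpu', poll=0.5, max_wait=30.0):
    """Send a module off the GPU and wait until allocated VRAM stops shrinking."""
    if device == 'cpu' or device == target:
        return
    before = torch.cuda.memory_allocated(device) / 1e6  # MB currently allocated
    module.to(target)
    deadline = time.time() + max_wait
    last = before
    while time.time() < deadline:
        time.sleep(poll)
        now = torch.cuda.memory_allocated(device) / 1e6
        if now <= 0 or now >= last:  # nothing left, or no longer shrinking
            break
        last = now
    print(f'freed ~{round(before - last)} MB on {device}')

A hard deadline avoids spinning forever if the allocation count never drops, for example when other tensors still hold memory on the device.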
@@ -258,9 +300,9 @@ def load_model_gfpgan():
    print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)

def load_model_real_esrgan():
    if thread_data.real_esrgan_file is None:
        print('load_model_real_esrgan called without setting real_esrgan_file')
        return
    if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
    #print('load_model_real_esrgan called without setting real_esrgan_file')
    #return
    model_path = thread_data.real_esrgan_file + ".pth"

    RealESRGAN_models = {
@@ -294,7 +336,7 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
        return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
    return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")

def apply_filters(filter_name, image_data):
def apply_filters(filter_name, image_data, model_path=None):
    print(f'Applying filter {filter_name}...')
    if isinstance(image_data, torch.Tensor):
        print(image_data)
@@ -303,12 +345,22 @@ def apply_filters(filter_name, image_data):
    gc()

    if filter_name == 'gfpgan':
        if model_path is not None and model_path != thread_data.gfpgan_file:
            thread_data.gfpgan_file = model_path
            load_model_gfpgan()
        elif not thread_data.model_gfpgan:
            load_model_gfpgan()
        if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
        print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
        _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
        image_data = output[:,:,::-1]

    if filter_name == 'real_esrgan':
        if model_path is not None and model_path != thread_data.real_esrgan_file:
            thread_data.real_esrgan_file = model_path
            load_model_real_esrgan()
        elif not thread_data.model_real_esrgan:
            load_model_real_esrgan()
        if thread_data.model_real_esrgan is None: raise Exception('Model "real_esrgan" not loaded.')
        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
        output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
@@ -338,6 +390,53 @@ def mk_img(req: Request):
            "detail": str(e)
        })

def update_temp_img(req, x_samples):
    partial_images = []
    for i in range(req.num_outputs):
        x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
        x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
        x_sample = x_sample.astype(np.uint8)
        img = Image.fromarray(x_sample)
        buf = BytesIO()
        img.save(buf, format='JPEG')
        buf.seek(0)

        del img, x_sample, x_sample_ddim
        # don't delete x_samples, it is used in the code that called this callback

        thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
        partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
    return partial_images

# Build and return the appropriate generator for do_mk_img
def get_image_progess_generator(req, extra_props=None):
    if not req.stream_progress_updates:
        def empty_callback(x_samples, i): return x_samples
        return empty_callback

    thread_data.partial_x_samples = None
    last_callback_time = -1
    def img_callback(x_samples, i):
        nonlocal last_callback_time

        thread_data.partial_x_samples = x_samples
        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
        last_callback_time = time.time()

        progress = {"step": i, "step_time": step_time}
        if extra_props is not None:
            progress.update(extra_props)

        if req.stream_image_progress and i % 5 == 0:
            progress['output'] = update_temp_img(req, x_samples)

        yield json.dumps(progress)

        if thread_data.stop_processing:
            raise UserInitiatedStop("User requested that we stop processing")
    return img_callback

def do_mk_img(req: Request):
    thread_data.stop_processing = False

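Note: update_temp_img() and get_image_progess_generator() above factor the old inline callback into a reusable factory that yields JSON progress lines, which do_mk_img can forward with yield from. A small sketch of that generator-callback shape (make_progress_callback and on_partial are hypothetical names, not the server's API):

import json
import time

def make_progress_callback(total_steps, on_partial=None):
    """Return a callback that yields one JSON progress line per sampler step."""
    last = [None]  # closure state for per-step timing
    def callback(step, samples=None):
        now = time.time()
        step_time = -1 if last[0] is None else now - last[0]
        last[0] = now
        progress = {"step": step, "total_steps": total_steps, "step_time": step_time}
        if on_partial is not None and samples is not None and step % 5 == 0:
            progress["output"] = on_partial(samples)  # e.g. encode a preview image
        yield json.dumps(progress)
    return callback

The caller is expected to drain the generator returned by each callback invocation (yield from callback(step, samples)), which is what keeps the streamed responses incremental.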
@@ -353,7 +452,7 @@ def do_mk_img(req: Request):
    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')

    needs_model_reload = False
    if thread_data.ckpt_file != req.use_stable_diffusion_model:
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        needs_model_reload = True

@@ -361,25 +460,19 @@ def do_mk_img(req: Request):
    if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
        (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
        thread_data.precision = 'full' if req.use_full_precision else 'autocast'
        load_model_ckpt()
        load_model_gfpgan()
        load_model_real_esrgan()
        needs_model_reload = False
        needs_model_reload = True

    if needs_model_reload:
        unload_models()
        unload_filters()
        load_model_ckpt()

    if req.use_face_correction is not None and req.use_face_correction != thread_data.gfpgan_file:
        thread_data.gfpgan_file = req.use_face_correction
        load_model_gfpgan()
    if req.use_upscale is not None and req.use_upscale != thread_data.real_esrgan_file:
        thread_data.real_esrgan_file = req.use_upscale
        load_model_real_esrgan()

    if thread_data.turbo != req.turbo:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo

    gc()

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_iter = 1
@@ -432,7 +525,7 @@ def do_mk_img(req: Request):
            if thread_data.device != "cpu" and thread_data.precision == "autocast":
                mask = mask.half()

        move_fs_to_cpu()
        wait_move(thread_data.modelFS) # Send to CPU and wait until complete.

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -450,7 +543,8 @@ def do_mk_img(req: Request):
        for prompts in tqdm(data, desc="data"):

            with precision_scope("cuda"):
                thread_data.modelCS.to(thread_data.device)
                if thread_data.reduced_memory:
                    thread_data.modelCS.to(thread_data.device)
                uc = None
                if req.guidance_scale != 1.0:
                    uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
@@ -470,47 +564,11 @@ def do_mk_img(req: Request):
                else:
                    c = thread_data.modelCS.get_learned_conditioning(prompts)

                thread_data.modelFS.to(thread_data.device)
                if thread_data.reduced_memory:
                    thread_data.modelFS.to(thread_data.device)

                partial_x_samples = None
                last_callback_time = -1
                def img_callback(x_samples, i):
                    nonlocal partial_x_samples, last_callback_time

                    partial_x_samples = x_samples

                    if req.stream_progress_updates:
                        n_steps = req.num_inference_steps if req.init_image is None else t_enc
                        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
                        last_callback_time = time.time()

                        progress = {"step": i, "total_steps": n_steps, "step_time": step_time}

                        if req.stream_image_progress and i % 5 == 0:
                            partial_images = []

                            for i in range(batch_size):
                                x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                                x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                                x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                                x_sample = x_sample.astype(np.uint8)
                                img = Image.fromarray(x_sample)
                                buf = BytesIO()
                                img.save(buf, format='JPEG')
                                buf.seek(0)

                                del img, x_sample, x_samples_ddim
                                # don't delete x_samples, it is used in the code that called this callback

                                thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
                                partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})

                            progress['output'] = partial_images

                        yield json.dumps(progress)

                    if thread_data.stop_processing:
                        raise UserInitiatedStop("User requested that we stop processing")
                n_steps = req.num_inference_steps if req.init_image is None else t_enc
                img_callback = get_image_progess_generator(req, {"total_steps": n_steps})

                # run the handler
                try:
@@ -520,16 +578,23 @@ def do_mk_img(req: Request):
                    else:
                        x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)

                    yield from x_samples

                    x_samples = partial_x_samples
                    if req.stream_progress_updates:
                        yield from x_samples
                    if hasattr(thread_data, 'partial_x_samples'):
                        if thread_data.partial_x_samples is not None:
                            x_samples = thread_data.partial_x_samples
                        del thread_data.partial_x_samples
                except UserInitiatedStop:
                    if partial_x_samples is None:
                    if not hasattr(thread_data, 'partial_x_samples'):
                        continue
                    if thread_data.partial_x_samples is None:
                        del thread_data.partial_x_samples
                        continue
                    x_samples = thread_data.partial_x_samples
                    del thread_data.partial_x_samples

                    x_samples = partial_x_samples

                print("saving images")
                print("decoding images")
                img_data = [None] * batch_size
                for i in range(batch_size):
                    img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
                    img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
@@ -538,7 +603,15 @@ def do_mk_img(req: Request):
                    x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                    x_sample = x_sample.astype(np.uint8)
                    img = Image.fromarray(x_sample)
                    img_data[i] = x_sample
                del x_samples, x_samples_ddim, x_sample

                if thread_data.reduced_memory:
                    wait_move(thread_data.modelFS) # Send to CPU and wait until complete.

                print("saving images")
                for i in range(batch_size):
                    img = Image.fromarray(img_data[i])

                    has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                        (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
@@ -562,19 +635,18 @@ def do_mk_img(req: Request):

                    if req.save_to_disk_path is not None:
                        res_image_orig.path_abs = img_out_path

                    del img

                    if has_filters and not thread_data.stop_processing:
                        filters_applied = []
                        if req.use_face_correction:
                            x_sample = apply_filters('gfpgan', x_sample)
                            img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                            filters_applied.append(req.use_face_correction)
                        if req.use_upscale:
                            x_sample = apply_filters('real_esrgan', x_sample)
                            img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                            filters_applied.append(req.use_upscale)
                        if (len(filters_applied) > 0):
                            filtered_image = Image.fromarray(x_sample)
                            filtered_image = Image.fromarray(img_data[i])
                            filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
                            response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                            res.images.append(response_image)
@@ -587,9 +659,10 @@ def do_mk_img(req: Request):
                    seeds += str(opt_seed) + ","
                    opt_seed += 1

            move_fs_to_cpu()
            if thread_data.reduced_memory:
                unload_filters()
            del img_data
            gc()
            del x_samples, x_samples_ddim, x_sample
            if thread_data.device != 'cpu':
                print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')

@@ -626,15 +699,7 @@ Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    if thread_data.device != "cpu":
        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
        print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
        thread_data.modelCS.to("cpu")
        while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
            print('Device:', thread_data.device, 'Waiting Memory transfer. Memory Used:', round(mem, 2), 'Mo')
            time.sleep(1)
        print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')

    wait_move(thread_data.modelCS) # Send to CPU and wait until complete.
    if sampler_name == 'ddim':
        thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

@@ -677,21 +742,10 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
        x_T=x_T,
        sampler = 'ddim'
    )

    yield from samples_ddim

def move_fs_to_cpu():
    if thread_data.device != "cpu":
        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
        print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
        thread_data.modelFS.to("cpu")
        while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
            print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'Mo')
            time.sleep(1)
        print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')

def gc():
    #gc.collect()
    gc_collect()
    if thread_data.device == 'cpu':
        return
    torch.cuda.empty_cache()