mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-30 14:49:55 +02:00
Runtime cleanup and moved apply_filters to it's own function
This commit is contained in:
parent
6ae3b77c2f
commit
c7f6763c48
@ -197,6 +197,35 @@ def load_model_real_esrgan(real_esrgan_to_use):
|
|||||||
|
|
||||||
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
||||||
|
|
||||||
|
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
||||||
|
if disk_path is None: return None
|
||||||
|
if session_id is None: return None
|
||||||
|
if ext is None: raise Exception('Missing ext')
|
||||||
|
|
||||||
|
session_out_path = os.path.join(disk_path, session_id)
|
||||||
|
os.makedirs(session_out_path, exist_ok=True)
|
||||||
|
|
||||||
|
prompt_flattened = filename_regex.sub('_', prompt)[:50]
|
||||||
|
img_id = str(uuid.uuid4())[-8:]
|
||||||
|
|
||||||
|
if suffix is not None:
|
||||||
|
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
|
||||||
|
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
|
||||||
|
|
||||||
|
def apply_filters(filter_name, image_data):
|
||||||
|
print(f'Applying filter {filter_name}...')
|
||||||
|
gc()
|
||||||
|
|
||||||
|
if filter_name == 'gfpgan':
|
||||||
|
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
|
image_data = output[:,:,::-1]
|
||||||
|
|
||||||
|
if filter_name == 'real_esrgan':
|
||||||
|
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||||
|
image_data = output[:,:,::-1]
|
||||||
|
|
||||||
|
return image_data
|
||||||
|
|
||||||
def mk_img(req: Request):
|
def mk_img(req: Request):
|
||||||
try:
|
try:
|
||||||
yield from do_mk_img(req)
|
yield from do_mk_img(req)
|
||||||
@ -283,23 +312,11 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
opt_prompt = req.prompt
|
opt_prompt = req.prompt
|
||||||
opt_seed = req.seed
|
opt_seed = req.seed
|
||||||
opt_n_samples = req.num_outputs
|
|
||||||
opt_n_iter = 1
|
opt_n_iter = 1
|
||||||
opt_scale = req.guidance_scale
|
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
opt_H = req.height
|
|
||||||
opt_W = req.width
|
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
opt_ddim_steps = req.num_inference_steps
|
|
||||||
opt_ddim_eta = 0.0
|
opt_ddim_eta = 0.0
|
||||||
opt_strength = req.prompt_strength
|
|
||||||
opt_save_to_disk_path = req.save_to_disk_path
|
|
||||||
opt_init_img = req.init_image
|
opt_init_img = req.init_image
|
||||||
opt_use_face_correction = req.use_face_correction
|
|
||||||
opt_use_upscale = req.use_upscale
|
|
||||||
opt_show_only_filtered = req.show_only_filtered_image
|
|
||||||
opt_format = req.output_format
|
|
||||||
opt_sampler_name = req.sampler
|
|
||||||
|
|
||||||
print(req.to_string(), '\n device', device)
|
print(req.to_string(), '\n device', device)
|
||||||
|
|
||||||
@ -307,7 +324,7 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
seed_everything(opt_seed)
|
seed_everything(opt_seed)
|
||||||
|
|
||||||
batch_size = opt_n_samples
|
batch_size = req.num_outputs
|
||||||
prompt = opt_prompt
|
prompt = opt_prompt
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
data = [batch_size * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
@ -327,7 +344,7 @@ def do_mk_img(req: Request):
|
|||||||
else:
|
else:
|
||||||
handler = _img2img
|
handler = _img2img
|
||||||
|
|
||||||
init_image = load_img(req.init_image, opt_W, opt_H)
|
init_image = load_img(req.init_image, req.width, req.height)
|
||||||
init_image = init_image.to(device)
|
init_image = init_image.to(device)
|
||||||
|
|
||||||
if device != "cpu" and precision == "autocast":
|
if device != "cpu" and precision == "autocast":
|
||||||
@ -339,7 +356,7 @@ def do_mk_img(req: Request):
|
|||||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
if req.mask is not None:
|
if req.mask is not None:
|
||||||
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
||||||
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||||
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||||
|
|
||||||
@ -348,12 +365,12 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
move_fs_to_cpu()
|
move_fs_to_cpu()
|
||||||
|
|
||||||
assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
t_enc = int(opt_strength * opt_ddim_steps)
|
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||||
print(f"target t_enc is {t_enc} steps")
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
|
session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
|
||||||
os.makedirs(session_out_path, exist_ok=True)
|
os.makedirs(session_out_path, exist_ok=True)
|
||||||
else:
|
else:
|
||||||
session_out_path = None
|
session_out_path = None
|
||||||
@ -366,7 +383,7 @@ def do_mk_img(req: Request):
|
|||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
modelCS.to(device)
|
modelCS.to(device)
|
||||||
uc = None
|
uc = None
|
||||||
if opt_scale != 1.0:
|
if req.guidance_scale != 1.0:
|
||||||
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
@ -393,7 +410,7 @@ def do_mk_img(req: Request):
|
|||||||
partial_x_samples = x_samples
|
partial_x_samples = x_samples
|
||||||
|
|
||||||
if req.stream_progress_updates:
|
if req.stream_progress_updates:
|
||||||
n_steps = opt_ddim_steps if req.init_image is None else t_enc
|
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
||||||
progress = {"step": i, "total_steps": n_steps}
|
progress = {"step": i, "total_steps": n_steps}
|
||||||
|
|
||||||
if req.stream_image_progress and i % 5 == 0:
|
if req.stream_image_progress and i % 5 == 0:
|
||||||
@ -425,9 +442,9 @@ def do_mk_img(req: Request):
|
|||||||
# run the handler
|
# run the handler
|
||||||
try:
|
try:
|
||||||
if handler == _txt2img:
|
if handler == _txt2img:
|
||||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
|
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
||||||
else:
|
else:
|
||||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
||||||
|
|
||||||
yield from x_samples
|
yield from x_samples
|
||||||
|
|
||||||
@ -447,69 +464,49 @@ def do_mk_img(req: Request):
|
|||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
img = Image.fromarray(x_sample)
|
img = Image.fromarray(x_sample)
|
||||||
|
|
||||||
has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
|
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
|
||||||
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN'))
|
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
|
||||||
|
|
||||||
return_orig_img = not has_filters or not opt_show_only_filtered
|
return_orig_img = not has_filters or not req.show_only_filtered_image
|
||||||
|
|
||||||
if stop_processing:
|
if stop_processing:
|
||||||
return_orig_img = True
|
return_orig_img = True
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
prompt_flattened = filename_regex.sub('_', prompts[0])
|
|
||||||
prompt_flattened = prompt_flattened[:50]
|
|
||||||
|
|
||||||
img_id = str(uuid.uuid4())[-8:]
|
|
||||||
|
|
||||||
file_path = f"{prompt_flattened}_{img_id}"
|
|
||||||
img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
|
|
||||||
meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
|
|
||||||
|
|
||||||
if return_orig_img:
|
if return_orig_img:
|
||||||
|
img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format)
|
||||||
save_image(img, img_out_path)
|
save_image(img, img_out_path)
|
||||||
|
meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], 'txt')
|
||||||
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name, req.negative_prompt, ckpt_file)
|
save_metadata(meta_out_path, req, prompts[0], opt_seed)
|
||||||
|
|
||||||
if return_orig_img:
|
if return_orig_img:
|
||||||
img_data = img_to_base64_str(img, opt_format)
|
img_data = img_to_base64_str(img, req.output_format)
|
||||||
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
|
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
|
||||||
res.images.append(res_image_orig)
|
res.images.append(res_image_orig)
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
res_image_orig.path_abs = img_out_path
|
res_image_orig.path_abs = img_out_path
|
||||||
|
|
||||||
del img
|
del img
|
||||||
|
|
||||||
if has_filters and not stop_processing:
|
if has_filters and not stop_processing:
|
||||||
print('Applying filters..')
|
|
||||||
|
|
||||||
gc()
|
|
||||||
filters_applied = []
|
filters_applied = []
|
||||||
|
if req.use_face_correction:
|
||||||
if opt_use_face_correction:
|
x_sample = apply_filters('gfpgan', x_sample)
|
||||||
_, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
filters_applied.append(req.use_face_correction)
|
||||||
x_sample = output[:,:,::-1]
|
if req.use_upscale:
|
||||||
filters_applied.append(opt_use_face_correction)
|
x_sample = apply_filters('real_esrgan', x_sample)
|
||||||
|
filters_applied.append(req.use_upscale)
|
||||||
if opt_use_upscale:
|
if (len(filters_applied) > 0):
|
||||||
output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1])
|
filtered_image = Image.fromarray(x_sample)
|
||||||
x_sample = output[:,:,::-1]
|
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
|
||||||
filters_applied.append(opt_use_upscale)
|
response_image = ResponseImage(data=filtered_img_data, seed=req.seed)
|
||||||
|
res.images.append(response_image)
|
||||||
filtered_image = Image.fromarray(x_sample)
|
if req.save_to_disk_path is not None:
|
||||||
|
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], req.output_format, "_".join(filters_applied))
|
||||||
filtered_img_data = img_to_base64_str(filtered_image, opt_format)
|
save_image(filtered_image, filtered_img_out_path)
|
||||||
res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
|
response_image.path_abs = filtered_img_out_path
|
||||||
res.images.append(res_image_filtered)
|
del filtered_image
|
||||||
|
|
||||||
filters_applied = "_".join(filters_applied)
|
|
||||||
|
|
||||||
if opt_save_to_disk_path is not None:
|
|
||||||
filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
|
|
||||||
save_image(filtered_image, filtered_img_out_path)
|
|
||||||
res_image_filtered.path_abs = filtered_img_out_path
|
|
||||||
|
|
||||||
del filtered_image
|
|
||||||
|
|
||||||
seeds += str(opt_seed) + ","
|
seeds += str(opt_seed) + ","
|
||||||
opt_seed += 1
|
opt_seed += 1
|
||||||
@ -529,9 +526,20 @@ def save_image(img, img_out_path):
|
|||||||
except:
|
except:
|
||||||
print('could not save the file', traceback.format_exc())
|
print('could not save the file', traceback.format_exc())
|
||||||
|
|
||||||
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name, negative_prompt, ckpt_file):
|
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
||||||
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}\nStable Diffusion Model: {ckpt_file + '.ckpt'}"
|
metadata = f"""{prompt}
|
||||||
|
Width: {req.width}
|
||||||
|
Height: {req.height}
|
||||||
|
Seed: {opt_seed}
|
||||||
|
Steps: {req.num_inference_steps}
|
||||||
|
Guidance Scale: {req.guidance_scale}
|
||||||
|
Prompt Strength: {req.prompt_strength}
|
||||||
|
Use Face Correction: {req.use_face_correction}
|
||||||
|
Use Upscaling: {req.use_upscale}
|
||||||
|
Sampler: {req.sampler}
|
||||||
|
Negative Prompt: {req.negative_prompt}
|
||||||
|
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
with open(meta_out_path, 'w') as f:
|
with open(meta_out_path, 'w') as f:
|
||||||
f.write(metadata)
|
f.write(metadata)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user