Show the progress percentage while generating images

commit 9f48d5e5ff
parent 3b47eb3b07
Author: cmdr2
Date: 2022-09-14 16:52:03 +05:30
5 changed files with 192 additions and 26 deletions


@@ -1,3 +1,4 @@
+import json
 import os, re
 import traceback
 import torch
@@ -332,15 +333,23 @@ def mk_img(req: Request):
             partial_x_samples = x_samples
+            if req.stream_progress_updates:
+                yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
             if stop_processing:
                 raise UserInitiatedStop("User requested that we stop processing")
         # run the handler
         try:
             if handler == _txt2img:
-                x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
+                x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
             else:
-                x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
+                x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
+            if req.stream_progress_updates:
+                yield from x_samples
+                x_samples = partial_x_samples
         except UserInitiatedStop:
             if partial_x_samples is None:
                 continue
@@ -421,7 +430,10 @@ def mk_img(req: Request):
     del x_samples
     print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
-    return res
+    if req.stream_progress_updates:
+        yield json.dumps(res.json())
+    else:
+        return res
 def save_image(img, img_out_path):
     try:
@@ -438,7 +450,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
     except:
         print('could not save the file', traceback.format_exc())
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
+def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
     if device != "cpu":
@@ -458,12 +470,16 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
         eta=opt_ddim_eta,
         x_T=start_code,
         img_callback=img_callback,
+        streaming_callbacks=streaming_callbacks,
         sampler = 'plms',
     )
-    return samples_ddim
+    if streaming_callbacks:
+        yield from samples_ddim
+    else:
+        return samples_ddim
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
+def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
     # encode (scaled latent)
     z_enc = model.stochastic_encode(
         init_latent,
@@ -480,10 +496,14 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
         unconditional_guidance_scale=opt_scale,
         unconditional_conditioning=uc,
         img_callback=img_callback,
+        streaming_callbacks=streaming_callbacks,
         sampler = 'ddim'
     )
-    return samples_ddim
+    if streaming_callbacks:
+        yield from samples_ddim
+    else:
+        return samples_ddim
 def gc():
     if device == 'cpu':
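
With this change, mk_img acts as a generator when req.stream_progress_updates is set: it yields JSON strings with "step" and "total_steps" during sampling, then a final JSON payload built from res.json(). A minimal sketch of how a caller might consume that stream to show a percentage follows; the run_with_progress helper and its way of detecting the final chunk are assumptions for illustration, not part of this commit.

import json

def run_with_progress(req):
    # Hypothetical consumer of mk_img's generator; assumes req.stream_progress_updates is True.
    final_payload = None
    for chunk in mk_img(req):             # each chunk is a JSON string yielded by mk_img
        update = json.loads(chunk)
        if 'step' in update:              # progress update: {"step": i, "total_steps": n}
            percent = 100 * (update['step'] + 1) / update['total_steps']
            print(f"progress: {percent:.0f}%")
        else:                             # final chunk, built from res.json()
            final_payload = update
    return final_payload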