Live preview of images

This commit is contained in:
cmdr2
2022-09-14 22:29:42 +05:30
parent 1d88a5b42e
commit 27071cfa29
4 changed files with 102 additions and 18 deletions

View File

@ -35,8 +35,8 @@ import base64
from io import BytesIO
# local
session_id = str(uuid.uuid4())[-8:]
stop_processing = False
temp_images = {}
ckpt_file = None
gfpgan_file = None
@ -192,10 +192,11 @@ def mk_img(req: Request):
stop_processing = False
res = Response()
res.session_id = session_id
res.request = req
res.images = []
temp_images.clear()
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
@ -296,7 +297,7 @@ def mk_img(req: Request):
print(f"target t_enc is {t_enc} steps")
if opt_save_to_disk_path is not None:
session_out_path = os.path.join(opt_save_to_disk_path, session_id)
session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
os.makedirs(session_out_path, exist_ok=True)
else:
session_out_path = None
@ -327,6 +328,8 @@ def mk_img(req: Request):
else:
c = modelCS.get_learned_conditioning(prompts)
modelFS.to(device)
partial_x_samples = None
def img_callback(x_samples, i):
nonlocal partial_x_samples
@ -334,7 +337,27 @@ def mk_img(req: Request):
partial_x_samples = x_samples
if req.stream_progress_updates:
yield json.dumps({"step": i, "total_steps": opt_ddim_steps})
progress = {"step": i, "total_steps": opt_ddim_steps}
if req.stream_image_progress:
partial_images = []
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
@ -356,8 +379,6 @@ def mk_img(req: Request):
x_samples = partial_x_samples
modelFS.to(device)
print("saving images")
for i in range(batch_size):