Image mask (inpainting)

This commit is contained in:
cmdr2
2022-09-15 17:54:03 +05:30
parent 444834a891
commit 7e7c110851
2 changed files with 96 additions and 43 deletions

View File

@ -4,7 +4,7 @@ import traceback
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO
from colorama import Fore
# local
stop_processing = False
@ -267,6 +268,8 @@ def mk_img(req: Request):
else:
precision_scope = nullcontext
mask = None
if req.init_image is None:
handler = _txt2img
@ -286,6 +289,14 @@ def mk_img(req: Request):
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast":
mask = mask.half()
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
@ -365,9 +376,9 @@ def mk_img(req: Request):
# run the handler
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
if req.stream_progress_updates:
yield from x_samples
@ -414,6 +425,8 @@ def mk_img(req: Request):
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
print('Applying filters..')
gc()
filters_applied = []
@ -451,6 +464,8 @@ def mk_img(req: Request):
del x_samples
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
print(Fore.GREEN + 'Task completed')
if req.stream_progress_updates:
yield json.dumps(res.json())
else:
@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
x_T=start_code,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
mask=mask,
sampler = 'plms',
)
@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
else:
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
opt_ddim_eta,
opt_ddim_steps,
)
x_T = None if mask is None else init_latent
# decode it
samples_ddim = model.sample(
t_enc,
@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_conditioning=uc,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
image = torch.from_numpy(image)
return 2.*image - 1.
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
image = base64_str_to_img(mask_str).convert("RGB")
w, h = image.size
print(f"loaded input mask of size ({w}, {h})")
if invert:
print("inverted")
image = ImageOps.invert(image)
# where_0, where_1 = np.where(image == 0), np.where(image == 255)
# image[where_0], image[where_1] = 255, 0
if h0 is not None and w0 is not None:
h, w = h0, w0
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
print(f"New mask size ({w}, {h})")
image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
buffered = BytesIO()