forked from extern/easydiffusion
Image mask (inpainting)
@@ -4,7 +4,7 @@ import traceback
 import torch
 import numpy as np
 from omegaconf import OmegaConf
-from PIL import Image
+from PIL import Image, ImageOps
 from tqdm import tqdm, trange
 from itertools import islice
 from einops import rearrange
@@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
 from . import Request, Response, Image as ResponseImage
 import base64
 from io import BytesIO
+from colorama import Fore
 
 # local
 stop_processing = False
@@ -267,6 +268,8 @@ def mk_img(req: Request):
 else:
     precision_scope = nullcontext
 
+mask = None
+
 if req.init_image is None:
     handler = _txt2img
 
@@ -286,6 +289,14 @@ def mk_img(req: Request):
 init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
 init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
 
+if req.mask is not None:
+    mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
+    mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
+    mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+
+    if device != "cpu" and precision == "autocast":
+        mask = mask.half()
+
 if device != "cpu":
     mem = torch.cuda.memory_allocated() / 1e6
     modelFS.to("cpu")
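For reference, a small standalone sketch of what the new mask-preparation block does to the tensor shapes. The sizes and the random tensor are illustrative stand-ins (a real mask comes from load_mask and is already resized to the latent resolution), and torch expand is used here in place of einops.repeat, which gives the same broadcast:

import torch

batch_size = 2
latent_h, latent_w = 64, 64                    # init_latent.shape[2], init_latent.shape[3]

mask = torch.rand(1, 3, latent_h, latent_w)    # load_mask returns a (1, 3, h, w) float tensor in [0, 1]

# keep a single channel, repeat it across the 4 latent channels, restore the batch dim
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)   # (1, 4, h, w)

# broadcast across the batch, as the diff does with einops repeat('1 ... -> b ...')
mask = mask.expand(batch_size, -1, -1, -1)                    # (2, 4, h, w)

print(mask.shape)                              # torch.Size([2, 4, 64, 64])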
@@ -365,9 +376,9 @@ def mk_img(req: Request):
 # run the handler
 try:
     if handler == _txt2img:
-        x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
+        x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
     else:
-        x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
+        x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
 
     if req.stream_progress_updates:
         yield from x_samples
@@ -414,6 +425,8 @@ def mk_img(req: Request):
 if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
     (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
 
+    print('Applying filters..')
+
     gc()
     filters_applied = []
 
@@ -451,6 +464,8 @@ def mk_img(req: Request):
 del x_samples
 print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
 
+print(Fore.GREEN + 'Task completed')
+
 if req.stream_progress_updates:
     yield json.dumps(res.json())
 else:
@@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
 except:
     print('could not save the file', traceback.format_exc())
 
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
+def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
     shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
 
     if device != "cpu":
@@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
     x_T=start_code,
     img_callback=img_callback,
     streaming_callbacks=streaming_callbacks,
+    mask=mask,
     sampler = 'plms',
 )
 
@@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
 else:
     return samples_ddim
 
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
+def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
     # encode (scaled latent)
     z_enc = model.stochastic_encode(
         init_latent,
@@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
     opt_ddim_eta,
     opt_ddim_steps,
 )
+x_T = None if mask is None else init_latent
+
 # decode it
 samples_ddim = model.sample(
     t_enc,
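The mask/x_T pair is what turns the DDIM decode into latent-space inpainting: x_T carries the original init_latent, and the sampler is expected to blend it back into the preserved region at every step. A rough, hedged sketch of that per-step blend (this is the general masked-latent technique, not the exact code inside model.sample):

import torch

def masked_blend(x_denoised, init_latent_noised, mask):
    # init_latent_noised stands for the original latent re-noised to the current timestep;
    # keep it where the mask preserves the image, keep the freshly denoised latent elsewhere
    return init_latent_noised * mask + x_denoised * (1. - mask)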
@@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
     unconditional_conditioning=uc,
     img_callback=img_callback,
     streaming_callbacks=streaming_callbacks,
+    mask=mask,
+    x_T=x_T,
     sampler = 'ddim'
 )
 
@@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
     image = torch.from_numpy(image)
     return 2.*image - 1.
 
+def load_mask(mask_str, h0, w0, newH, newW, invert=False):
+    image = base64_str_to_img(mask_str).convert("RGB")
+    w, h = image.size
+    print(f"loaded input mask of size ({w}, {h})")
+
+    if invert:
+        print("inverted")
+        image = ImageOps.invert(image)
+        # where_0, where_1 = np.where(image == 0), np.where(image == 255)
+        # image[where_0], image[where_1] = 255, 0
+
+    if h0 is not None and w0 is not None:
+        h, w = h0, w0
+
+    w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+
+    print(f"New mask size ({w}, {h})")
+    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
+    image = np.array(image)
+
+    image = image.astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return image
+
 # https://stackoverflow.com/a/61114178
 def img_to_base64_str(img):
     buffered = BytesIO()
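A hedged usage sketch of the new load_mask helper, assuming a base64-encoded PNG mask string of the kind req.mask carries (base64_str_to_img is the decoder already defined in this file; the string below is a hypothetical, elided placeholder):

# hypothetical, truncated base64 image string
mask_b64 = "data:image/png;base64,iVBORw0KGgo..."

# request size 512x512, latent grid 64x64 (opt_W // 8, opt_H // 8 for Stable Diffusion)
mask = load_mask(mask_b64, 512, 512, 64, 64, invert=True)

print(mask.shape)   # torch.Size([1, 3, 64, 64]), float32 values in [0, 1]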