Image mask (inpainting)

cmdr2 2022-09-15 17:54:03 +05:30
parent 444834a891
commit 7e7c110851
2 changed files with 96 additions and 43 deletions

View File

@@ -294,13 +294,21 @@
</div>
<div id="editor-inputs-init-image" class="row">
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /> </button><br/>
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /><br/>
<div id="init_image_preview_container" class="image_preview_container">
<img id="init_image_preview" src="" width="100" height="100" />
<button id="init_image_clear" class="image_clear_btn">X</button>
</div>
</div>
<div id="editor-inputs-mask_setting">
<label for="mask"><b>Image Mask:</b> (optional) </label> <input id="mask" name="mask" type="file" /><br/>
<div id="mask_preview_container" class="image_preview_container">
<img id="mask_preview" src="" width="100" height="100" />
<button id="mask_clear" class="image_clear_btn">X</button>
</div>
</div>
<div id="editor-inputs-tags-container" class="row">
<label>Tags: <small>(click a tag to remove it)</small></label>
<div id="editor-inputs-tags-list">
@@ -453,8 +461,8 @@ let widthField = document.querySelector('#width')
let heightField = document.querySelector('#height')
let initImageSelector = document.querySelector("#init_image")
let initImagePreview = document.querySelector("#init_image_preview")
// let maskImageSelector = document.querySelector("#mask")
// let maskImagePreview = document.querySelector("#mask_preview")
let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview")
let turboField = document.querySelector('#turbo')
let useCPUField = document.querySelector('#use_cpu')
let useFullPrecisionField = document.querySelector('#use_full_precision')
@@ -479,9 +487,9 @@ let initImagePreviewContainer = document.querySelector('#init_image_preview_cont
let initImageClearBtn = document.querySelector('#init_image_clear')
let promptStrengthContainer = document.querySelector('#prompt_strength_container')
// let maskSetting = document.querySelector('#mask_setting')
// let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
// let maskImageClearBtn = document.querySelector('#mask_clear')
let maskSetting = document.querySelector('#editor-inputs-mask_setting')
let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
let maskImageClearBtn = document.querySelector('#mask_clear')
let editorModifierEntries = document.querySelector('#editor-modifiers-entries')
let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list')
@@ -844,7 +852,7 @@ async function doMakeImage(reqBody, batchCount) {
initImagePreviewContainer.style.display = 'block'
promptStrengthContainer.style.display = 'block'
// maskSetting.style.display = 'block'
maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = seed
@@ -949,9 +957,9 @@ async function makeImage() {
reqBody['init_image'] = initImagePreview.src
reqBody['prompt_strength'] = promptStrengthField.value
// if (IMAGE_REGEX.test(maskImagePreview.src)) {
// reqBody['mask'] = maskImagePreview.src
// }
if (IMAGE_REGEX.test(maskImagePreview.src)) {
reqBody['mask'] = maskImagePreview.src
}
}
if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
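For context, maskImagePreview.src holds a data URL (e.g. data:image/png;base64,...) set when the mask file is previewed, so reqBody['mask'] reaches the server as a base64 data URL. A minimal sketch of decoding such a string into a PIL image, assuming the standard data-URL prefix; the repository's own base64_str_to_img helper used later by load_mask is not shown in this diff, so this is illustrative only:

import base64
from io import BytesIO
from PIL import Image

def decode_data_url(data_url):
    # Drop the "data:image/...;base64," header if present, then decode the payload.
    payload = data_url.split(",", 1)[1] if "," in data_url else data_url
    return Image.open(BytesIO(base64.b64decode(payload)))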
@@ -1210,7 +1218,7 @@ function showInitImagePreview() {
if (initImageSelector.files.length === 0) {
initImagePreviewContainer.style.display = 'none'
promptStrengthContainer.style.display = 'none'
// maskSetting.style.display = 'none'
maskSetting.style.display = 'none'
return
}
@@ -1223,7 +1231,7 @@ function showInitImagePreview() {
initImagePreviewContainer.style.display = 'block'
promptStrengthContainer.style.display = 'block'
// maskSetting.style.display = 'block'
maskSetting.style.display = 'block'
})
if (file) {
@@ -1235,45 +1243,45 @@ showInitImagePreview()
initImageClearBtn.addEventListener('click', function() {
initImageSelector.value = null
// maskImageSelector.value = null
maskImageSelector.value = null
initImagePreview.src = ''
// maskImagePreview.src = ''
maskImagePreview.src = ''
initImagePreviewContainer.style.display = 'none'
// maskImagePreviewContainer.style.display = 'none'
maskImagePreviewContainer.style.display = 'none'
// maskSetting.style.display = 'none'
maskSetting.style.display = 'none'
promptStrengthContainer.style.display = 'none'
})
// function showMaskImagePreview() {
// if (maskImageSelector.files.length === 0) {
// maskImagePreviewContainer.style.display = 'none'
// return
// }
function showMaskImagePreview() {
if (maskImageSelector.files.length === 0) {
maskImagePreviewContainer.style.display = 'none'
return
}
// let reader = new FileReader()
// let file = maskImageSelector.files[0]
let reader = new FileReader()
let file = maskImageSelector.files[0]
// reader.addEventListener('load', function() {
// maskImagePreview.src = reader.result
// maskImagePreviewContainer.style.display = 'block'
// })
reader.addEventListener('load', function() {
maskImagePreview.src = reader.result
maskImagePreviewContainer.style.display = 'block'
})
// if (file) {
// reader.readAsDataURL(file)
// }
// }
// maskImageSelector.addEventListener('change', showMaskImagePreview)
// showMaskImagePreview()
if (file) {
reader.readAsDataURL(file)
}
}
maskImageSelector.addEventListener('change', showMaskImagePreview)
showMaskImagePreview()
// maskImageClearBtn.addEventListener('click', function() {
// maskImageSelector.value = null
// maskImagePreview.src = ''
// maskImagePreviewContainer.style.display = 'none'
// })
maskImageClearBtn.addEventListener('click', function() {
maskImageSelector.value = null
maskImagePreview.src = ''
maskImagePreviewContainer.style.display = 'none'
})
// https://stackoverflow.com/a/8212878
function millisecondsToStr(milliseconds) {

View File

@@ -4,7 +4,7 @@ import traceback
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
@@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO
from colorama import Fore
# local
stop_processing = False
@@ -267,6 +268,8 @@ def mk_img(req: Request):
else:
precision_scope = nullcontext
mask = None
if req.init_image is None:
handler = _txt2img
@@ -286,6 +289,14 @@ def mk_img(req: Request):
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast":
mask = mask.half()
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
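As a stand-alone illustration of the shape handling above (sizes are assumptions: a 512x512 image with downsampling factor 8 gives a 64x64 latent), load_mask returns a (1, 3, h, w) tensor, one channel is kept and tiled across the 4 latent channels, and the result is repeated per batch item:

import torch

batch_size, latent_h, latent_w = 1, 64, 64
mask = torch.ones(1, 3, latent_h, latent_w)      # shape as returned by load_mask
m = mask[0][0]                                   # (64, 64): keep a single channel
m = m.unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)  # (1, 4, 64, 64): match the 4 latent channels
m = m.repeat(batch_size, 1, 1, 1)                # plain-torch equivalent of repeat(mask, '1 ... -> b ...')
print(m.shape)                                   # torch.Size([1, 4, 64, 64])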
@@ -365,9 +376,9 @@ def mk_img(req: Request):
# run the handler
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
if req.stream_progress_updates:
yield from x_samples
@@ -414,6 +425,8 @@ def mk_img(req: Request):
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
print('Applying filters..')
gc()
filters_applied = []
@@ -451,6 +464,8 @@ def mk_img(req: Request):
del x_samples
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
print(Fore.GREEN + 'Task completed')
if req.stream_progress_updates:
yield json.dumps(res.json())
else:
@@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
x_T=start_code,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
mask=mask,
sampler = 'plms',
)
@@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
else:
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
opt_ddim_eta,
opt_ddim_steps,
)
x_T = None if mask is None else init_latent
# decode it
samples_ddim = model.sample(
t_enc,
@@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_conditioning=uc,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
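The mask and x_T arguments are consumed inside model.sample, which is not part of this diff. A hedged sketch of the per-step blend a mask-aware DDIM loop typically performs (the function and argument names here are hypothetical), where mask == 1 marks the region kept from the original latent:

def masked_step(x, init_latent, mask, add_noise, t):
    # Re-noise the original latent to the current timestep, then take the
    # preserved region from it and the freshly generated region from x.
    x_orig = add_noise(init_latent, t)
    return mask * x_orig + (1.0 - mask) * x

Under that convention, the invert=True passed to load_mask in mk_img would make a white-painted area become 0 (regenerated) while the untouched region stays 1 (preserved).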
@@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
image = torch.from_numpy(image)
return 2.*image - 1.
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
image = base64_str_to_img(mask_str).convert("RGB")
w, h = image.size
print(f"loaded input mask of size ({w}, {h})")
if invert:
print("inverted")
image = ImageOps.invert(image)
# where_0, where_1 = np.where(image == 0), np.where(image == 255)
# image[where_0], image[where_1] = 255, 0
if h0 is not None and w0 is not None:
h, w = h0, w0
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
print(f"New mask size ({w}, {h})")
image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
image = np.array(image)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
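For the defaults assumed above (a 512x512 input and latent downsampling factor 8, so newW == newH == 64), load_mask returns a (1, 3, 64, 64) float tensor with values in [0, 1]. A self-contained sketch of the same resize-and-normalize steps, starting from an in-memory PIL image instead of a base64 string:

import numpy as np
import torch
from PIL import Image, ImageOps

img = Image.new("RGB", (512, 512), (255, 255, 255))           # stand-in for the decoded mask
img = ImageOps.invert(img)                                     # the invert=True path
img = img.resize((64, 64), resample=Image.Resampling.LANCZOS)
arr = np.array(img).astype(np.float32) / 255.0
mask = torch.from_numpy(arr[None].transpose(0, 3, 1, 2))
print(mask.shape)                                              # torch.Size([1, 3, 64, 64])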
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
buffered = BytesIO()