Image mask (inpainting)

This commit is contained in:
cmdr2 2022-09-15 17:54:03 +05:30
parent 444834a891
commit 7e7c110851
2 changed files with 96 additions and 43 deletions

View File

@@ -294,13 +294,21 @@
</div> </div>
<div id="editor-inputs-init-image" class="row"> <div id="editor-inputs-init-image" class="row">
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /> </button><br/> <label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /><br/>
<div id="init_image_preview_container" class="image_preview_container"> <div id="init_image_preview_container" class="image_preview_container">
<img id="init_image_preview" src="" width="100" height="100" /> <img id="init_image_preview" src="" width="100" height="100" />
<button id="init_image_clear" class="image_clear_btn">X</button> <button id="init_image_clear" class="image_clear_btn">X</button>
</div> </div>
</div> </div>
<div id="editor-inputs-mask_setting">
<label for="mask"><b>Image Mask:</b> (optional) </label> <input id="mask" name="mask" type="file" /><br/>
<div id="mask_preview_container" class="image_preview_container">
<img id="mask_preview" src="" width="100" height="100" />
<button id="mask_clear" class="image_clear_btn">X</button>
</div>
</div>
<div id="editor-inputs-tags-container" class="row"> <div id="editor-inputs-tags-container" class="row">
<label>Tags: <small>(click a tag to remove it)</small></label> <label>Tags: <small>(click a tag to remove it)</small></label>
<div id="editor-inputs-tags-list"> <div id="editor-inputs-tags-list">
@@ -453,8 +461,8 @@ let widthField = document.querySelector('#width')
let heightField = document.querySelector('#height') let heightField = document.querySelector('#height')
let initImageSelector = document.querySelector("#init_image") let initImageSelector = document.querySelector("#init_image")
let initImagePreview = document.querySelector("#init_image_preview") let initImagePreview = document.querySelector("#init_image_preview")
// let maskImageSelector = document.querySelector("#mask") let maskImageSelector = document.querySelector("#mask")
// let maskImagePreview = document.querySelector("#mask_preview") let maskImagePreview = document.querySelector("#mask_preview")
let turboField = document.querySelector('#turbo') let turboField = document.querySelector('#turbo')
let useCPUField = document.querySelector('#use_cpu') let useCPUField = document.querySelector('#use_cpu')
let useFullPrecisionField = document.querySelector('#use_full_precision') let useFullPrecisionField = document.querySelector('#use_full_precision')
@@ -479,9 +487,9 @@ let initImagePreviewContainer = document.querySelector('#init_image_preview_cont
let initImageClearBtn = document.querySelector('#init_image_clear') let initImageClearBtn = document.querySelector('#init_image_clear')
let promptStrengthContainer = document.querySelector('#prompt_strength_container') let promptStrengthContainer = document.querySelector('#prompt_strength_container')
// let maskSetting = document.querySelector('#mask_setting') let maskSetting = document.querySelector('#editor-inputs-mask_setting')
// let maskImagePreviewContainer = document.querySelector('#mask_preview_container') let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
// let maskImageClearBtn = document.querySelector('#mask_clear') let maskImageClearBtn = document.querySelector('#mask_clear')
let editorModifierEntries = document.querySelector('#editor-modifiers-entries') let editorModifierEntries = document.querySelector('#editor-modifiers-entries')
let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list') let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list')
@@ -844,7 +852,7 @@ async function doMakeImage(reqBody, batchCount) {
initImagePreviewContainer.style.display = 'block' initImagePreviewContainer.style.display = 'block'
promptStrengthContainer.style.display = 'block' promptStrengthContainer.style.display = 'block'
// maskSetting.style.display = 'block' maskSetting.style.display = 'block'
randomSeedField.checked = false randomSeedField.checked = false
seedField.value = seed seedField.value = seed
@@ -949,9 +957,9 @@ async function makeImage() {
reqBody['init_image'] = initImagePreview.src reqBody['init_image'] = initImagePreview.src
reqBody['prompt_strength'] = promptStrengthField.value reqBody['prompt_strength'] = promptStrengthField.value
// if (IMAGE_REGEX.test(maskImagePreview.src)) { if (IMAGE_REGEX.test(maskImagePreview.src)) {
// reqBody['mask'] = maskImagePreview.src reqBody['mask'] = maskImagePreview.src
// } }
} }
if (saveToDiskField.checked && diskPathField.value.trim() !== '') { if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
@@ -1210,7 +1218,7 @@ function showInitImagePreview() {
if (initImageSelector.files.length === 0) { if (initImageSelector.files.length === 0) {
initImagePreviewContainer.style.display = 'none' initImagePreviewContainer.style.display = 'none'
promptStrengthContainer.style.display = 'none' promptStrengthContainer.style.display = 'none'
// maskSetting.style.display = 'none' maskSetting.style.display = 'none'
return return
} }
@@ -1223,7 +1231,7 @@ function showInitImagePreview() {
initImagePreviewContainer.style.display = 'block' initImagePreviewContainer.style.display = 'block'
promptStrengthContainer.style.display = 'block' promptStrengthContainer.style.display = 'block'
// maskSetting.style.display = 'block' maskSetting.style.display = 'block'
}) })
if (file) { if (file) {
@@ -1235,45 +1243,45 @@ showInitImagePreview()
initImageClearBtn.addEventListener('click', function() { initImageClearBtn.addEventListener('click', function() {
initImageSelector.value = null initImageSelector.value = null
// maskImageSelector.value = null maskImageSelector.value = null
initImagePreview.src = '' initImagePreview.src = ''
// maskImagePreview.src = '' maskImagePreview.src = ''
initImagePreviewContainer.style.display = 'none' initImagePreviewContainer.style.display = 'none'
// maskImagePreviewContainer.style.display = 'none' maskImagePreviewContainer.style.display = 'none'
// maskSetting.style.display = 'none' maskSetting.style.display = 'none'
promptStrengthContainer.style.display = 'none' promptStrengthContainer.style.display = 'none'
}) })
// function showMaskImagePreview() { function showMaskImagePreview() {
// if (maskImageSelector.files.length === 0) { if (maskImageSelector.files.length === 0) {
// maskImagePreviewContainer.style.display = 'none' maskImagePreviewContainer.style.display = 'none'
// return return
// } }
// let reader = new FileReader() let reader = new FileReader()
// let file = maskImageSelector.files[0] let file = maskImageSelector.files[0]
// reader.addEventListener('load', function() { reader.addEventListener('load', function() {
// maskImagePreview.src = reader.result maskImagePreview.src = reader.result
// maskImagePreviewContainer.style.display = 'block' maskImagePreviewContainer.style.display = 'block'
// }) })
// if (file) { if (file) {
// reader.readAsDataURL(file) reader.readAsDataURL(file)
// } }
// } }
// maskImageSelector.addEventListener('change', showMaskImagePreview) maskImageSelector.addEventListener('change', showMaskImagePreview)
// showMaskImagePreview() showMaskImagePreview()
// maskImageClearBtn.addEventListener('click', function() { maskImageClearBtn.addEventListener('click', function() {
// maskImageSelector.value = null maskImageSelector.value = null
// maskImagePreview.src = '' maskImagePreview.src = ''
// maskImagePreviewContainer.style.display = 'none' maskImagePreviewContainer.style.display = 'none'
// }) })
// https://stackoverflow.com/a/8212878 // https://stackoverflow.com/a/8212878
function millisecondsToStr(milliseconds) { function millisecondsToStr(milliseconds) {

View File

@@ -4,7 +4,7 @@ import traceback
import torch import torch
import numpy as np import numpy as np
from omegaconf import OmegaConf from omegaconf import OmegaConf
from PIL import Image from PIL import Image, ImageOps
from tqdm import tqdm, trange from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange from einops import rearrange
@@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
from . import Request, Response, Image as ResponseImage from . import Request, Response, Image as ResponseImage
import base64 import base64
from io import BytesIO from io import BytesIO
from colorama import Fore
# local # local
stop_processing = False stop_processing = False
@@ -267,6 +268,8 @@ def mk_img(req: Request):
else: else:
precision_scope = nullcontext precision_scope = nullcontext
mask = None
if req.init_image is None: if req.init_image is None:
handler = _txt2img handler = _txt2img
@@ -286,6 +289,14 @@ def mk_img(req: Request):
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast":
mask = mask.half()
if device != "cpu": if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6 mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu") modelFS.to("cpu")
@@ -365,9 +376,9 @@ def mk_img(req: Request):
# run the handler # run the handler
try: try:
if handler == _txt2img: if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates) x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
else: else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates) x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
if req.stream_progress_updates: if req.stream_progress_updates:
yield from x_samples yield from x_samples
@@ -414,6 +425,8 @@ def mk_img(req: Request):
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \ if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')): (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
print('Applying filters..')
gc() gc()
filters_applied = [] filters_applied = []
@@ -451,6 +464,8 @@ def mk_img(req: Request):
del x_samples del x_samples
print("memory_final = ", torch.cuda.memory_allocated() / 1e6) print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
print(Fore.GREEN + 'Task completed')
if req.stream_progress_updates: if req.stream_progress_updates:
yield json.dumps(res.json()) yield json.dumps(res.json())
else: else:
@@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except: except:
print('could not save the file', traceback.format_exc()) print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks): def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu": if device != "cpu":
@@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
x_T=start_code, x_T=start_code,
img_callback=img_callback, img_callback=img_callback,
streaming_callbacks=streaming_callbacks, streaming_callbacks=streaming_callbacks,
mask=mask,
sampler = 'plms', sampler = 'plms',
) )
@@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
else: else:
return samples_ddim return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks): def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
# encode (scaled latent) # encode (scaled latent)
z_enc = model.stochastic_encode( z_enc = model.stochastic_encode(
init_latent, init_latent,
@@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
opt_ddim_eta, opt_ddim_eta,
opt_ddim_steps, opt_ddim_steps,
) )
x_T = None if mask is None else init_latent
# decode it # decode it
samples_ddim = model.sample( samples_ddim = model.sample(
t_enc, t_enc,
@@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_conditioning=uc, unconditional_conditioning=uc,
img_callback=img_callback, img_callback=img_callback,
streaming_callbacks=streaming_callbacks, streaming_callbacks=streaming_callbacks,
mask=mask,
x_T=x_T,
sampler = 'ddim' sampler = 'ddim'
) )
@@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
image = torch.from_numpy(image) image = torch.from_numpy(image)
return 2.*image - 1. return 2.*image - 1.
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode a base64-encoded mask image into a float tensor for inpainting.

    The mask is converted to RGB, optionally inverted, resized to the
    latent-space resolution and scaled into the [0, 1] float range.

    Args:
        mask_str: base64-encoded image string (decoded via base64_str_to_img).
        h0, w0: requested output height/width. NOTE(review): these are
            effectively unused — the resize below targets (newW, newH) —
            and the visible caller passes them as (opt_W, opt_H), i.e.
            swapped relative to the parameter names. Kept for interface
            compatibility; confirm intent before removing.
        newH, newW: target height/width (the latent tensor's spatial dims).
        invert: when True, invert the mask so painted areas flip meaning.

    Returns:
        A float32 torch tensor of shape (1, 3, newH, newW), values in [0, 1].
    """
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)

    # Resize directly to the latent-space resolution. The previous
    # round-to-multiple-of-64 computation was dead code: its result was
    # printed but never used, so the logged "new size" was misleading.
    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
    print(f"New mask size ({newW}, {newH})")

    # HWC uint8 -> NCHW float32 in [0, 1] with a leading batch dimension.
    image = np.array(image, dtype=np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(image)
# https://stackoverflow.com/a/61114178 # https://stackoverflow.com/a/61114178
def img_to_base64_str(img): def img_to_base64_str(img):
buffered = BytesIO() buffered = BytesIO()