forked from extern/easydiffusion
Image mask (inpainting)
This commit is contained in:
parent
444834a891
commit
7e7c110851
@ -294,13 +294,21 @@
|
||||
</div>
|
||||
|
||||
<div id="editor-inputs-init-image" class="row">
|
||||
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /> </button><br/>
|
||||
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /><br/>
|
||||
<div id="init_image_preview_container" class="image_preview_container">
|
||||
<img id="init_image_preview" src="" width="100" height="100" />
|
||||
<button id="init_image_clear" class="image_clear_btn">X</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="editor-inputs-mask_setting">
|
||||
<label for="mask"><b>Image Mask:</b> (optional) </label> <input id="mask" name="mask" type="file" /><br/>
|
||||
<div id="mask_preview_container" class="image_preview_container">
|
||||
<img id="mask_preview" src="" width="100" height="100" />
|
||||
<button id="mask_clear" class="image_clear_btn">X</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="editor-inputs-tags-container" class="row">
|
||||
<label>Tags: <small>(click a tag to remove it)</small></label>
|
||||
<div id="editor-inputs-tags-list">
|
||||
@ -453,8 +461,8 @@ let widthField = document.querySelector('#width')
|
||||
let heightField = document.querySelector('#height')
|
||||
let initImageSelector = document.querySelector("#init_image")
|
||||
let initImagePreview = document.querySelector("#init_image_preview")
|
||||
// let maskImageSelector = document.querySelector("#mask")
|
||||
// let maskImagePreview = document.querySelector("#mask_preview")
|
||||
let maskImageSelector = document.querySelector("#mask")
|
||||
let maskImagePreview = document.querySelector("#mask_preview")
|
||||
let turboField = document.querySelector('#turbo')
|
||||
let useCPUField = document.querySelector('#use_cpu')
|
||||
let useFullPrecisionField = document.querySelector('#use_full_precision')
|
||||
@ -479,9 +487,9 @@ let initImagePreviewContainer = document.querySelector('#init_image_preview_cont
|
||||
let initImageClearBtn = document.querySelector('#init_image_clear')
|
||||
let promptStrengthContainer = document.querySelector('#prompt_strength_container')
|
||||
|
||||
// let maskSetting = document.querySelector('#mask_setting')
|
||||
// let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
|
||||
// let maskImageClearBtn = document.querySelector('#mask_clear')
|
||||
let maskSetting = document.querySelector('#editor-inputs-mask_setting')
|
||||
let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
|
||||
let maskImageClearBtn = document.querySelector('#mask_clear')
|
||||
|
||||
let editorModifierEntries = document.querySelector('#editor-modifiers-entries')
|
||||
let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list')
|
||||
@ -844,7 +852,7 @@ async function doMakeImage(reqBody, batchCount) {
|
||||
initImagePreviewContainer.style.display = 'block'
|
||||
promptStrengthContainer.style.display = 'block'
|
||||
|
||||
// maskSetting.style.display = 'block'
|
||||
maskSetting.style.display = 'block'
|
||||
|
||||
randomSeedField.checked = false
|
||||
seedField.value = seed
|
||||
@ -949,9 +957,9 @@ async function makeImage() {
|
||||
reqBody['init_image'] = initImagePreview.src
|
||||
reqBody['prompt_strength'] = promptStrengthField.value
|
||||
|
||||
// if (IMAGE_REGEX.test(maskImagePreview.src)) {
|
||||
// reqBody['mask'] = maskImagePreview.src
|
||||
// }
|
||||
if (IMAGE_REGEX.test(maskImagePreview.src)) {
|
||||
reqBody['mask'] = maskImagePreview.src
|
||||
}
|
||||
}
|
||||
|
||||
if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
|
||||
@ -1210,7 +1218,7 @@ function showInitImagePreview() {
|
||||
if (initImageSelector.files.length === 0) {
|
||||
initImagePreviewContainer.style.display = 'none'
|
||||
promptStrengthContainer.style.display = 'none'
|
||||
// maskSetting.style.display = 'none'
|
||||
maskSetting.style.display = 'none'
|
||||
return
|
||||
}
|
||||
|
||||
@ -1223,7 +1231,7 @@ function showInitImagePreview() {
|
||||
initImagePreviewContainer.style.display = 'block'
|
||||
promptStrengthContainer.style.display = 'block'
|
||||
|
||||
// maskSetting.style.display = 'block'
|
||||
maskSetting.style.display = 'block'
|
||||
})
|
||||
|
||||
if (file) {
|
||||
@ -1235,45 +1243,45 @@ showInitImagePreview()
|
||||
|
||||
initImageClearBtn.addEventListener('click', function() {
|
||||
initImageSelector.value = null
|
||||
// maskImageSelector.value = null
|
||||
maskImageSelector.value = null
|
||||
|
||||
initImagePreview.src = ''
|
||||
// maskImagePreview.src = ''
|
||||
maskImagePreview.src = ''
|
||||
|
||||
initImagePreviewContainer.style.display = 'none'
|
||||
// maskImagePreviewContainer.style.display = 'none'
|
||||
maskImagePreviewContainer.style.display = 'none'
|
||||
|
||||
// maskSetting.style.display = 'none'
|
||||
maskSetting.style.display = 'none'
|
||||
|
||||
promptStrengthContainer.style.display = 'none'
|
||||
})
|
||||
|
||||
// function showMaskImagePreview() {
|
||||
// if (maskImageSelector.files.length === 0) {
|
||||
// maskImagePreviewContainer.style.display = 'none'
|
||||
// return
|
||||
// }
|
||||
function showMaskImagePreview() {
|
||||
if (maskImageSelector.files.length === 0) {
|
||||
maskImagePreviewContainer.style.display = 'none'
|
||||
return
|
||||
}
|
||||
|
||||
// let reader = new FileReader()
|
||||
// let file = maskImageSelector.files[0]
|
||||
let reader = new FileReader()
|
||||
let file = maskImageSelector.files[0]
|
||||
|
||||
// reader.addEventListener('load', function() {
|
||||
// maskImagePreview.src = reader.result
|
||||
// maskImagePreviewContainer.style.display = 'block'
|
||||
// })
|
||||
reader.addEventListener('load', function() {
|
||||
maskImagePreview.src = reader.result
|
||||
maskImagePreviewContainer.style.display = 'block'
|
||||
})
|
||||
|
||||
// if (file) {
|
||||
// reader.readAsDataURL(file)
|
||||
// }
|
||||
// }
|
||||
// maskImageSelector.addEventListener('change', showMaskImagePreview)
|
||||
// showMaskImagePreview()
|
||||
if (file) {
|
||||
reader.readAsDataURL(file)
|
||||
}
|
||||
}
|
||||
maskImageSelector.addEventListener('change', showMaskImagePreview)
|
||||
showMaskImagePreview()
|
||||
|
||||
// maskImageClearBtn.addEventListener('click', function() {
|
||||
// maskImageSelector.value = null
|
||||
// maskImagePreview.src = ''
|
||||
// maskImagePreviewContainer.style.display = 'none'
|
||||
// })
|
||||
maskImageClearBtn.addEventListener('click', function() {
|
||||
maskImageSelector.value = null
|
||||
maskImagePreview.src = ''
|
||||
maskImagePreviewContainer.style.display = 'none'
|
||||
})
|
||||
|
||||
// https://stackoverflow.com/a/8212878
|
||||
function millisecondsToStr(milliseconds) {
|
||||
|
@ -4,7 +4,7 @@ import traceback
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageOps
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange
|
||||
@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||
from . import Request, Response, Image as ResponseImage
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from colorama import Fore
|
||||
|
||||
# local
|
||||
stop_processing = False
|
||||
@ -267,6 +268,8 @@ def mk_img(req: Request):
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
|
||||
mask = None
|
||||
|
||||
if req.init_image is None:
|
||||
handler = _txt2img
|
||||
|
||||
@ -286,6 +289,14 @@ def mk_img(req: Request):
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if req.mask is not None:
|
||||
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
||||
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
mask = mask.half()
|
||||
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
@ -365,9 +376,9 @@ def mk_img(req: Request):
|
||||
# run the handler
|
||||
try:
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield from x_samples
|
||||
@ -414,6 +425,8 @@ def mk_img(req: Request):
|
||||
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
|
||||
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
|
||||
|
||||
print('Applying filters..')
|
||||
|
||||
gc()
|
||||
filters_applied = []
|
||||
|
||||
@ -451,6 +464,8 @@ def mk_img(req: Request):
|
||||
del x_samples
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
|
||||
print(Fore.GREEN + 'Task completed')
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield json.dumps(res.json())
|
||||
else:
|
||||
@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
|
||||
except:
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
x_T=start_code,
|
||||
img_callback=img_callback,
|
||||
streaming_callbacks=streaming_callbacks,
|
||||
mask=mask,
|
||||
sampler = 'plms',
|
||||
)
|
||||
|
||||
@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
else:
|
||||
return samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
opt_ddim_eta,
|
||||
opt_ddim_steps,
|
||||
)
|
||||
x_T = None if mask is None else init_latent
|
||||
|
||||
# decode it
|
||||
samples_ddim = model.sample(
|
||||
t_enc,
|
||||
@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=img_callback,
|
||||
streaming_callbacks=streaming_callbacks,
|
||||
mask=mask,
|
||||
x_T=x_T,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
|
||||
image = torch.from_numpy(image)
|
||||
return 2.*image - 1.
|
||||
|
||||
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
|
||||
image = base64_str_to_img(mask_str).convert("RGB")
|
||||
w, h = image.size
|
||||
print(f"loaded input mask of size ({w}, {h})")
|
||||
|
||||
if invert:
|
||||
print("inverted")
|
||||
image = ImageOps.invert(image)
|
||||
# where_0, where_1 = np.where(image == 0), np.where(image == 255)
|
||||
# image[where_0], image[where_1] = 255, 0
|
||||
|
||||
if h0 is not None and w0 is not None:
|
||||
h, w = h0, w0
|
||||
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
|
||||
print(f"New mask size ({w}, {h})")
|
||||
image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
|
||||
image = np.array(image)
|
||||
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return image
|
||||
|
||||
# https://stackoverflow.com/a/61114178
|
||||
def img_to_base64_str(img):
|
||||
buffered = BytesIO()
|
||||
|
Loading…
Reference in New Issue
Block a user