forked from extern/easydiffusion
Image mask (inpainting)
This commit is contained in:
parent
444834a891
commit
7e7c110851
@ -294,13 +294,21 @@
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div id="editor-inputs-init-image" class="row">
|
<div id="editor-inputs-init-image" class="row">
|
||||||
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /> </button><br/>
|
<label for="init_image"><b>Initial Image:</b> (optional) </label> <input id="init_image" name="init_image" type="file" /><br/>
|
||||||
<div id="init_image_preview_container" class="image_preview_container">
|
<div id="init_image_preview_container" class="image_preview_container">
|
||||||
<img id="init_image_preview" src="" width="100" height="100" />
|
<img id="init_image_preview" src="" width="100" height="100" />
|
||||||
<button id="init_image_clear" class="image_clear_btn">X</button>
|
<button id="init_image_clear" class="image_clear_btn">X</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div id="editor-inputs-mask_setting">
|
||||||
|
<label for="mask"><b>Image Mask:</b> (optional) </label> <input id="mask" name="mask" type="file" /><br/>
|
||||||
|
<div id="mask_preview_container" class="image_preview_container">
|
||||||
|
<img id="mask_preview" src="" width="100" height="100" />
|
||||||
|
<button id="mask_clear" class="image_clear_btn">X</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div id="editor-inputs-tags-container" class="row">
|
<div id="editor-inputs-tags-container" class="row">
|
||||||
<label>Tags: <small>(click a tag to remove it)</small></label>
|
<label>Tags: <small>(click a tag to remove it)</small></label>
|
||||||
<div id="editor-inputs-tags-list">
|
<div id="editor-inputs-tags-list">
|
||||||
@ -453,8 +461,8 @@ let widthField = document.querySelector('#width')
|
|||||||
let heightField = document.querySelector('#height')
|
let heightField = document.querySelector('#height')
|
||||||
let initImageSelector = document.querySelector("#init_image")
|
let initImageSelector = document.querySelector("#init_image")
|
||||||
let initImagePreview = document.querySelector("#init_image_preview")
|
let initImagePreview = document.querySelector("#init_image_preview")
|
||||||
// let maskImageSelector = document.querySelector("#mask")
|
let maskImageSelector = document.querySelector("#mask")
|
||||||
// let maskImagePreview = document.querySelector("#mask_preview")
|
let maskImagePreview = document.querySelector("#mask_preview")
|
||||||
let turboField = document.querySelector('#turbo')
|
let turboField = document.querySelector('#turbo')
|
||||||
let useCPUField = document.querySelector('#use_cpu')
|
let useCPUField = document.querySelector('#use_cpu')
|
||||||
let useFullPrecisionField = document.querySelector('#use_full_precision')
|
let useFullPrecisionField = document.querySelector('#use_full_precision')
|
||||||
@ -479,9 +487,9 @@ let initImagePreviewContainer = document.querySelector('#init_image_preview_cont
|
|||||||
let initImageClearBtn = document.querySelector('#init_image_clear')
|
let initImageClearBtn = document.querySelector('#init_image_clear')
|
||||||
let promptStrengthContainer = document.querySelector('#prompt_strength_container')
|
let promptStrengthContainer = document.querySelector('#prompt_strength_container')
|
||||||
|
|
||||||
// let maskSetting = document.querySelector('#mask_setting')
|
let maskSetting = document.querySelector('#editor-inputs-mask_setting')
|
||||||
// let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
|
let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
|
||||||
// let maskImageClearBtn = document.querySelector('#mask_clear')
|
let maskImageClearBtn = document.querySelector('#mask_clear')
|
||||||
|
|
||||||
let editorModifierEntries = document.querySelector('#editor-modifiers-entries')
|
let editorModifierEntries = document.querySelector('#editor-modifiers-entries')
|
||||||
let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list')
|
let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list')
|
||||||
@ -844,7 +852,7 @@ async function doMakeImage(reqBody, batchCount) {
|
|||||||
initImagePreviewContainer.style.display = 'block'
|
initImagePreviewContainer.style.display = 'block'
|
||||||
promptStrengthContainer.style.display = 'block'
|
promptStrengthContainer.style.display = 'block'
|
||||||
|
|
||||||
// maskSetting.style.display = 'block'
|
maskSetting.style.display = 'block'
|
||||||
|
|
||||||
randomSeedField.checked = false
|
randomSeedField.checked = false
|
||||||
seedField.value = seed
|
seedField.value = seed
|
||||||
@ -949,9 +957,9 @@ async function makeImage() {
|
|||||||
reqBody['init_image'] = initImagePreview.src
|
reqBody['init_image'] = initImagePreview.src
|
||||||
reqBody['prompt_strength'] = promptStrengthField.value
|
reqBody['prompt_strength'] = promptStrengthField.value
|
||||||
|
|
||||||
// if (IMAGE_REGEX.test(maskImagePreview.src)) {
|
if (IMAGE_REGEX.test(maskImagePreview.src)) {
|
||||||
// reqBody['mask'] = maskImagePreview.src
|
reqBody['mask'] = maskImagePreview.src
|
||||||
// }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
|
if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
|
||||||
@ -1210,7 +1218,7 @@ function showInitImagePreview() {
|
|||||||
if (initImageSelector.files.length === 0) {
|
if (initImageSelector.files.length === 0) {
|
||||||
initImagePreviewContainer.style.display = 'none'
|
initImagePreviewContainer.style.display = 'none'
|
||||||
promptStrengthContainer.style.display = 'none'
|
promptStrengthContainer.style.display = 'none'
|
||||||
// maskSetting.style.display = 'none'
|
maskSetting.style.display = 'none'
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1223,7 +1231,7 @@ function showInitImagePreview() {
|
|||||||
initImagePreviewContainer.style.display = 'block'
|
initImagePreviewContainer.style.display = 'block'
|
||||||
promptStrengthContainer.style.display = 'block'
|
promptStrengthContainer.style.display = 'block'
|
||||||
|
|
||||||
// maskSetting.style.display = 'block'
|
maskSetting.style.display = 'block'
|
||||||
})
|
})
|
||||||
|
|
||||||
if (file) {
|
if (file) {
|
||||||
@ -1235,45 +1243,45 @@ showInitImagePreview()
|
|||||||
|
|
||||||
initImageClearBtn.addEventListener('click', function() {
|
initImageClearBtn.addEventListener('click', function() {
|
||||||
initImageSelector.value = null
|
initImageSelector.value = null
|
||||||
// maskImageSelector.value = null
|
maskImageSelector.value = null
|
||||||
|
|
||||||
initImagePreview.src = ''
|
initImagePreview.src = ''
|
||||||
// maskImagePreview.src = ''
|
maskImagePreview.src = ''
|
||||||
|
|
||||||
initImagePreviewContainer.style.display = 'none'
|
initImagePreviewContainer.style.display = 'none'
|
||||||
// maskImagePreviewContainer.style.display = 'none'
|
maskImagePreviewContainer.style.display = 'none'
|
||||||
|
|
||||||
// maskSetting.style.display = 'none'
|
maskSetting.style.display = 'none'
|
||||||
|
|
||||||
promptStrengthContainer.style.display = 'none'
|
promptStrengthContainer.style.display = 'none'
|
||||||
})
|
})
|
||||||
|
|
||||||
// function showMaskImagePreview() {
|
function showMaskImagePreview() {
|
||||||
// if (maskImageSelector.files.length === 0) {
|
if (maskImageSelector.files.length === 0) {
|
||||||
// maskImagePreviewContainer.style.display = 'none'
|
maskImagePreviewContainer.style.display = 'none'
|
||||||
// return
|
return
|
||||||
// }
|
}
|
||||||
|
|
||||||
// let reader = new FileReader()
|
let reader = new FileReader()
|
||||||
// let file = maskImageSelector.files[0]
|
let file = maskImageSelector.files[0]
|
||||||
|
|
||||||
// reader.addEventListener('load', function() {
|
reader.addEventListener('load', function() {
|
||||||
// maskImagePreview.src = reader.result
|
maskImagePreview.src = reader.result
|
||||||
// maskImagePreviewContainer.style.display = 'block'
|
maskImagePreviewContainer.style.display = 'block'
|
||||||
// })
|
})
|
||||||
|
|
||||||
// if (file) {
|
if (file) {
|
||||||
// reader.readAsDataURL(file)
|
reader.readAsDataURL(file)
|
||||||
// }
|
}
|
||||||
// }
|
}
|
||||||
// maskImageSelector.addEventListener('change', showMaskImagePreview)
|
maskImageSelector.addEventListener('change', showMaskImagePreview)
|
||||||
// showMaskImagePreview()
|
showMaskImagePreview()
|
||||||
|
|
||||||
// maskImageClearBtn.addEventListener('click', function() {
|
maskImageClearBtn.addEventListener('click', function() {
|
||||||
// maskImageSelector.value = null
|
maskImageSelector.value = null
|
||||||
// maskImagePreview.src = ''
|
maskImagePreview.src = ''
|
||||||
// maskImagePreviewContainer.style.display = 'none'
|
maskImagePreviewContainer.style.display = 'none'
|
||||||
// })
|
})
|
||||||
|
|
||||||
// https://stackoverflow.com/a/8212878
|
// https://stackoverflow.com/a/8212878
|
||||||
function millisecondsToStr(milliseconds) {
|
function millisecondsToStr(milliseconds) {
|
||||||
|
@ -4,7 +4,7 @@ import traceback
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from PIL import Image
|
from PIL import Image, ImageOps
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
|
|||||||
from . import Request, Response, Image as ResponseImage
|
from . import Request, Response, Image as ResponseImage
|
||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from colorama import Fore
|
||||||
|
|
||||||
# local
|
# local
|
||||||
stop_processing = False
|
stop_processing = False
|
||||||
@ -267,6 +268,8 @@ def mk_img(req: Request):
|
|||||||
else:
|
else:
|
||||||
precision_scope = nullcontext
|
precision_scope = nullcontext
|
||||||
|
|
||||||
|
mask = None
|
||||||
|
|
||||||
if req.init_image is None:
|
if req.init_image is None:
|
||||||
handler = _txt2img
|
handler = _txt2img
|
||||||
|
|
||||||
@ -286,6 +289,14 @@ def mk_img(req: Request):
|
|||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
|
if req.mask is not None:
|
||||||
|
mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
||||||
|
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||||
|
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||||
|
|
||||||
|
if device != "cpu" and precision == "autocast":
|
||||||
|
mask = mask.half()
|
||||||
|
|
||||||
if device != "cpu":
|
if device != "cpu":
|
||||||
mem = torch.cuda.memory_allocated() / 1e6
|
mem = torch.cuda.memory_allocated() / 1e6
|
||||||
modelFS.to("cpu")
|
modelFS.to("cpu")
|
||||||
@ -365,9 +376,9 @@ def mk_img(req: Request):
|
|||||||
# run the handler
|
# run the handler
|
||||||
try:
|
try:
|
||||||
if handler == _txt2img:
|
if handler == _txt2img:
|
||||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
|
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
|
||||||
else:
|
else:
|
||||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
|
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
|
||||||
|
|
||||||
if req.stream_progress_updates:
|
if req.stream_progress_updates:
|
||||||
yield from x_samples
|
yield from x_samples
|
||||||
@ -414,6 +425,8 @@ def mk_img(req: Request):
|
|||||||
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
|
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
|
||||||
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
|
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
|
||||||
|
|
||||||
|
print('Applying filters..')
|
||||||
|
|
||||||
gc()
|
gc()
|
||||||
filters_applied = []
|
filters_applied = []
|
||||||
|
|
||||||
@ -451,6 +464,8 @@ def mk_img(req: Request):
|
|||||||
del x_samples
|
del x_samples
|
||||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||||
|
|
||||||
|
print(Fore.GREEN + 'Task completed')
|
||||||
|
|
||||||
if req.stream_progress_updates:
|
if req.stream_progress_updates:
|
||||||
yield json.dumps(res.json())
|
yield json.dumps(res.json())
|
||||||
else:
|
else:
|
||||||
@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
|
|||||||
except:
|
except:
|
||||||
print('could not save the file', traceback.format_exc())
|
print('could not save the file', traceback.format_exc())
|
||||||
|
|
||||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
|
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
|
||||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||||
|
|
||||||
if device != "cpu":
|
if device != "cpu":
|
||||||
@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
|||||||
x_T=start_code,
|
x_T=start_code,
|
||||||
img_callback=img_callback,
|
img_callback=img_callback,
|
||||||
streaming_callbacks=streaming_callbacks,
|
streaming_callbacks=streaming_callbacks,
|
||||||
|
mask=mask,
|
||||||
sampler = 'plms',
|
sampler = 'plms',
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
|||||||
else:
|
else:
|
||||||
return samples_ddim
|
return samples_ddim
|
||||||
|
|
||||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
|
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
|
||||||
# encode (scaled latent)
|
# encode (scaled latent)
|
||||||
z_enc = model.stochastic_encode(
|
z_enc = model.stochastic_encode(
|
||||||
init_latent,
|
init_latent,
|
||||||
@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
|||||||
opt_ddim_eta,
|
opt_ddim_eta,
|
||||||
opt_ddim_steps,
|
opt_ddim_steps,
|
||||||
)
|
)
|
||||||
|
x_T = None if mask is None else init_latent
|
||||||
|
|
||||||
# decode it
|
# decode it
|
||||||
samples_ddim = model.sample(
|
samples_ddim = model.sample(
|
||||||
t_enc,
|
t_enc,
|
||||||
@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
|||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
img_callback=img_callback,
|
img_callback=img_callback,
|
||||||
streaming_callbacks=streaming_callbacks,
|
streaming_callbacks=streaming_callbacks,
|
||||||
|
mask=mask,
|
||||||
|
x_T=x_T,
|
||||||
sampler = 'ddim'
|
sampler = 'ddim'
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
|
|||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
return 2.*image - 1.
|
return 2.*image - 1.
|
||||||
|
|
||||||
|
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
|
||||||
|
image = base64_str_to_img(mask_str).convert("RGB")
|
||||||
|
w, h = image.size
|
||||||
|
print(f"loaded input mask of size ({w}, {h})")
|
||||||
|
|
||||||
|
if invert:
|
||||||
|
print("inverted")
|
||||||
|
image = ImageOps.invert(image)
|
||||||
|
# where_0, where_1 = np.where(image == 0), np.where(image == 255)
|
||||||
|
# image[where_0], image[where_1] = 255, 0
|
||||||
|
|
||||||
|
if h0 is not None and w0 is not None:
|
||||||
|
h, w = h0, w0
|
||||||
|
|
||||||
|
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||||
|
|
||||||
|
print(f"New mask size ({w}, {h})")
|
||||||
|
image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
|
||||||
|
image = np.array(image)
|
||||||
|
|
||||||
|
image = image.astype(np.float32) / 255.0
|
||||||
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
return image
|
||||||
|
|
||||||
# https://stackoverflow.com/a/61114178
|
# https://stackoverflow.com/a/61114178
|
||||||
def img_to_base64_str(img):
|
def img_to_base64_str(img):
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
|
Loading…
Reference in New Issue
Block a user