diff --git a/ui/index.html b/ui/index.html
index bfe8b480..497a3126 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -294,13 +294,21 @@
@@ -453,8 +461,8 @@ let widthField = document.querySelector('#width')
let heightField = document.querySelector('#height')
let initImageSelector = document.querySelector("#init_image")
let initImagePreview = document.querySelector("#init_image_preview")
-// let maskImageSelector = document.querySelector("#mask")
-// let maskImagePreview = document.querySelector("#mask_preview")
+let maskImageSelector = document.querySelector("#mask")
+let maskImagePreview = document.querySelector("#mask_preview")
let turboField = document.querySelector('#turbo')
let useCPUField = document.querySelector('#use_cpu')
let useFullPrecisionField = document.querySelector('#use_full_precision')
@@ -479,9 +487,9 @@ let initImagePreviewContainer = document.querySelector('#init_image_preview_cont
let initImageClearBtn = document.querySelector('#init_image_clear')
let promptStrengthContainer = document.querySelector('#prompt_strength_container')
-// let maskSetting = document.querySelector('#mask_setting')
-// let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
-// let maskImageClearBtn = document.querySelector('#mask_clear')
+let maskSetting = document.querySelector('#editor-inputs-mask_setting')
+let maskImagePreviewContainer = document.querySelector('#mask_preview_container')
+let maskImageClearBtn = document.querySelector('#mask_clear')
let editorModifierEntries = document.querySelector('#editor-modifiers-entries')
let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list')
@@ -844,7 +852,7 @@ async function doMakeImage(reqBody, batchCount) {
initImagePreviewContainer.style.display = 'block'
promptStrengthContainer.style.display = 'block'
- // maskSetting.style.display = 'block'
+ maskSetting.style.display = 'block'
randomSeedField.checked = false
seedField.value = seed
@@ -949,9 +957,9 @@ async function makeImage() {
reqBody['init_image'] = initImagePreview.src
reqBody['prompt_strength'] = promptStrengthField.value
- // if (IMAGE_REGEX.test(maskImagePreview.src)) {
- // reqBody['mask'] = maskImagePreview.src
- // }
+ if (IMAGE_REGEX.test(maskImagePreview.src)) {
+ reqBody['mask'] = maskImagePreview.src
+ }
}
if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
@@ -1210,7 +1218,7 @@ function showInitImagePreview() {
if (initImageSelector.files.length === 0) {
initImagePreviewContainer.style.display = 'none'
promptStrengthContainer.style.display = 'none'
- // maskSetting.style.display = 'none'
+ maskSetting.style.display = 'none'
return
}
@@ -1223,7 +1231,7 @@ function showInitImagePreview() {
initImagePreviewContainer.style.display = 'block'
promptStrengthContainer.style.display = 'block'
- // maskSetting.style.display = 'block'
+ maskSetting.style.display = 'block'
})
if (file) {
@@ -1235,45 +1243,45 @@ showInitImagePreview()
initImageClearBtn.addEventListener('click', function() {
initImageSelector.value = null
- // maskImageSelector.value = null
+ maskImageSelector.value = null
initImagePreview.src = ''
- // maskImagePreview.src = ''
+ maskImagePreview.src = ''
initImagePreviewContainer.style.display = 'none'
- // maskImagePreviewContainer.style.display = 'none'
+ maskImagePreviewContainer.style.display = 'none'
- // maskSetting.style.display = 'none'
+ maskSetting.style.display = 'none'
promptStrengthContainer.style.display = 'none'
})
-// function showMaskImagePreview() {
-// if (maskImageSelector.files.length === 0) {
-// maskImagePreviewContainer.style.display = 'none'
-// return
-// }
+function showMaskImagePreview() {
+ if (maskImageSelector.files.length === 0) {
+ maskImagePreviewContainer.style.display = 'none'
+ return
+ }
-// let reader = new FileReader()
-// let file = maskImageSelector.files[0]
+ let reader = new FileReader()
+ let file = maskImageSelector.files[0]
-// reader.addEventListener('load', function() {
-// maskImagePreview.src = reader.result
-// maskImagePreviewContainer.style.display = 'block'
-// })
+ reader.addEventListener('load', function() {
+ maskImagePreview.src = reader.result
+ maskImagePreviewContainer.style.display = 'block'
+ })
-// if (file) {
-// reader.readAsDataURL(file)
-// }
-// }
-// maskImageSelector.addEventListener('change', showMaskImagePreview)
-// showMaskImagePreview()
+ if (file) {
+ reader.readAsDataURL(file)
+ }
+}
+maskImageSelector.addEventListener('change', showMaskImagePreview)
+showMaskImagePreview()
-// maskImageClearBtn.addEventListener('click', function() {
-// maskImageSelector.value = null
-// maskImagePreview.src = ''
-// maskImagePreviewContainer.style.display = 'none'
-// })
+maskImageClearBtn.addEventListener('click', function() {
+ maskImageSelector.value = null
+ maskImagePreview.src = ''
+ maskImagePreviewContainer.style.display = 'none'
+})
// https://stackoverflow.com/a/8212878
function millisecondsToStr(milliseconds) {
diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index 613cb967..04b59d32 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -4,7 +4,7 @@ import traceback
import torch
import numpy as np
from omegaconf import OmegaConf
-from PIL import Image
+from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
@@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]')
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO
+from colorama import Fore
# local
stop_processing = False
@@ -267,6 +268,8 @@ def mk_img(req: Request):
else:
precision_scope = nullcontext
+ mask = None
+
if req.init_image is None:
handler = _txt2img
@@ -286,6 +289,14 @@ def mk_img(req: Request):
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
+ if req.mask is not None:
+ mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device)
+ mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
+ mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+
+ if device != "cpu" and precision == "autocast":
+ mask = mask.half()
+
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
@@ -365,9 +376,9 @@ def mk_img(req: Request):
# run the handler
try:
if handler == _txt2img:
- x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates)
+ x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
else:
- x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates)
+ x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
if req.stream_progress_updates:
yield from x_samples
@@ -414,6 +425,8 @@ def mk_img(req: Request):
if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
(opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):
+ print('Applying filters..')
+
gc()
filters_applied = []
@@ -451,6 +464,8 @@ def mk_img(req: Request):
del x_samples
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
+ print(Fore.GREEN + 'Task completed')
+
if req.stream_progress_updates:
yield json.dumps(res.json())
else:
@@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks):
+def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
x_T=start_code,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
+ mask=mask,
sampler = 'plms',
)
@@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
else:
return samples_ddim
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks):
+def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
opt_ddim_eta,
opt_ddim_steps,
)
+ x_T = None if mask is None else init_latent
+
# decode it
samples_ddim = model.sample(
t_enc,
@@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_conditioning=uc,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
+ mask=mask,
+ x_T=x_T,
sampler = 'ddim'
)
@@ -566,6 +586,31 @@ def load_img(img_str, w0, h0):
image = torch.from_numpy(image)
return 2.*image - 1.
+def load_mask(mask_str, h0, w0, newH, newW, invert=False):
+ image = base64_str_to_img(mask_str).convert("RGB")
+ w, h = image.size
+ print(f"loaded input mask of size ({w}, {h})")
+
+ if invert:
+ print("inverted")
+ image = ImageOps.invert(image)
+ # where_0, where_1 = np.where(image == 0), np.where(image == 255)
+ # image[where_0], image[where_1] = 255, 0
+
+ if h0 is not None and w0 is not None:
+ h, w = h0, w0
+
+ w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
+
+ print(f"New mask size ({w}, {h})")
+ image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
+ image = np.array(image)
+
+ image = image.astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return image
+
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
buffered = BytesIO()