diff --git a/ui/index.html b/ui/index.html index bfe8b480..497a3126 100644 --- a/ui/index.html +++ b/ui/index.html @@ -294,13 +294,21 @@
-
+
+
+
+
+ + +
+
+
@@ -453,8 +461,8 @@ let widthField = document.querySelector('#width') let heightField = document.querySelector('#height') let initImageSelector = document.querySelector("#init_image") let initImagePreview = document.querySelector("#init_image_preview") -// let maskImageSelector = document.querySelector("#mask") -// let maskImagePreview = document.querySelector("#mask_preview") +let maskImageSelector = document.querySelector("#mask") +let maskImagePreview = document.querySelector("#mask_preview") let turboField = document.querySelector('#turbo') let useCPUField = document.querySelector('#use_cpu') let useFullPrecisionField = document.querySelector('#use_full_precision') @@ -479,9 +487,9 @@ let initImagePreviewContainer = document.querySelector('#init_image_preview_cont let initImageClearBtn = document.querySelector('#init_image_clear') let promptStrengthContainer = document.querySelector('#prompt_strength_container') -// let maskSetting = document.querySelector('#mask_setting') -// let maskImagePreviewContainer = document.querySelector('#mask_preview_container') -// let maskImageClearBtn = document.querySelector('#mask_clear') +let maskSetting = document.querySelector('#editor-inputs-mask_setting') +let maskImagePreviewContainer = document.querySelector('#mask_preview_container') +let maskImageClearBtn = document.querySelector('#mask_clear') let editorModifierEntries = document.querySelector('#editor-modifiers-entries') let editorModifierTagsList = document.querySelector('#editor-inputs-tags-list') @@ -844,7 +852,7 @@ async function doMakeImage(reqBody, batchCount) { initImagePreviewContainer.style.display = 'block' promptStrengthContainer.style.display = 'block' - // maskSetting.style.display = 'block' + maskSetting.style.display = 'block' randomSeedField.checked = false seedField.value = seed @@ -949,9 +957,9 @@ async function makeImage() { reqBody['init_image'] = initImagePreview.src reqBody['prompt_strength'] = promptStrengthField.value - // if 
(IMAGE_REGEX.test(maskImagePreview.src)) { - // reqBody['mask'] = maskImagePreview.src - // } + if (IMAGE_REGEX.test(maskImagePreview.src)) { + reqBody['mask'] = maskImagePreview.src + } } if (saveToDiskField.checked && diskPathField.value.trim() !== '') { @@ -1210,7 +1218,7 @@ function showInitImagePreview() { if (initImageSelector.files.length === 0) { initImagePreviewContainer.style.display = 'none' promptStrengthContainer.style.display = 'none' - // maskSetting.style.display = 'none' + maskSetting.style.display = 'none' return } @@ -1223,7 +1231,7 @@ function showInitImagePreview() { initImagePreviewContainer.style.display = 'block' promptStrengthContainer.style.display = 'block' - // maskSetting.style.display = 'block' + maskSetting.style.display = 'block' }) if (file) { @@ -1235,45 +1243,45 @@ showInitImagePreview() initImageClearBtn.addEventListener('click', function() { initImageSelector.value = null - // maskImageSelector.value = null + maskImageSelector.value = null initImagePreview.src = '' - // maskImagePreview.src = '' + maskImagePreview.src = '' initImagePreviewContainer.style.display = 'none' - // maskImagePreviewContainer.style.display = 'none' + maskImagePreviewContainer.style.display = 'none' - // maskSetting.style.display = 'none' + maskSetting.style.display = 'none' promptStrengthContainer.style.display = 'none' }) -// function showMaskImagePreview() { -// if (maskImageSelector.files.length === 0) { -// maskImagePreviewContainer.style.display = 'none' -// return -// } +function showMaskImagePreview() { + if (maskImageSelector.files.length === 0) { + maskImagePreviewContainer.style.display = 'none' + return + } -// let reader = new FileReader() -// let file = maskImageSelector.files[0] + let reader = new FileReader() + let file = maskImageSelector.files[0] -// reader.addEventListener('load', function() { -// maskImagePreview.src = reader.result -// maskImagePreviewContainer.style.display = 'block' -// }) + reader.addEventListener('load', 
function() { + maskImagePreview.src = reader.result + maskImagePreviewContainer.style.display = 'block' + }) -// if (file) { -// reader.readAsDataURL(file) -// } -// } -// maskImageSelector.addEventListener('change', showMaskImagePreview) -// showMaskImagePreview() + if (file) { + reader.readAsDataURL(file) + } +} +maskImageSelector.addEventListener('change', showMaskImagePreview) +showMaskImagePreview() -// maskImageClearBtn.addEventListener('click', function() { -// maskImageSelector.value = null -// maskImagePreview.src = '' -// maskImagePreviewContainer.style.display = 'none' -// }) +maskImageClearBtn.addEventListener('click', function() { + maskImageSelector.value = null + maskImagePreview.src = '' + maskImagePreviewContainer.style.display = 'none' +}) // https://stackoverflow.com/a/8212878 function millisecondsToStr(milliseconds) { diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 613cb967..04b59d32 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -4,7 +4,7 @@ import traceback import torch import numpy as np from omegaconf import OmegaConf -from PIL import Image +from PIL import Image, ImageOps from tqdm import tqdm, trange from itertools import islice from einops import rearrange @@ -33,6 +33,7 @@ filename_regex = re.compile('[^a-zA-Z0-9]') from . import Request, Response, Image as ResponseImage import base64 from io import BytesIO +from colorama import Fore # local stop_processing = False @@ -267,6 +268,8 @@ def mk_img(req: Request): else: precision_scope = nullcontext + mask = None + if req.init_image is None: handler = _txt2img @@ -286,6 +289,14 @@ def mk_img(req: Request): init_image = repeat(init_image, '1 ... 
-> b ...', b=batch_size) init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space + if req.mask is not None: + mask = load_mask(req.mask, opt_W, opt_H, init_latent.shape[2], init_latent.shape[3], True).to(device) + mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) + mask = repeat(mask, '1 ... -> b ...', b=batch_size) + + if device != "cpu" and precision == "autocast": + mask = mask.half() + if device != "cpu": mem = torch.cuda.memory_allocated() / 1e6 modelFS.to("cpu") @@ -365,9 +376,9 @@ def mk_img(req: Request): # run the handler try: if handler == _txt2img: - x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates) + x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask) else: - x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates) + x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask) if req.stream_progress_updates: yield from x_samples @@ -414,6 +425,8 @@ def mk_img(req: Request): if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \ (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')): + print('Applying filters..') + gc() filters_applied = [] @@ -451,6 +464,8 @@ def mk_img(req: Request): del x_samples print("memory_final = ", torch.cuda.memory_allocated() / 1e6) + print(Fore.GREEN + 'Task completed') + if req.stream_progress_updates: yield json.dumps(res.json()) else: @@ -471,7 +486,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps except: print('could not save the 
file', traceback.format_exc()) -def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks): +def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask): shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] if device != "cpu": @@ -492,6 +507,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, x_T=start_code, img_callback=img_callback, streaming_callbacks=streaming_callbacks, + mask=mask, sampler = 'plms', ) @@ -500,7 +516,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, else: return samples_ddim -def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks): +def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask): # encode (scaled latent) z_enc = model.stochastic_encode( init_latent, @@ -509,6 +525,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o opt_ddim_eta, opt_ddim_steps, ) + x_T = None if mask is None else init_latent + # decode it samples_ddim = model.sample( t_enc, @@ -518,6 +536,8 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o unconditional_conditioning=uc, img_callback=img_callback, streaming_callbacks=streaming_callbacks, + mask=mask, + x_T=x_T, sampler = 'ddim' ) @@ -566,6 +586,31 @@ def load_img(img_str, w0, h0): image = torch.from_numpy(image) return 2.*image - 1. 
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode an encoded mask image into a float tensor sized (newH, newW).

    Args:
        mask_str: encoded mask image; decoded with ``base64_str_to_img``.
        h0, w0: requested image height / width. They have no influence on the
            returned tensor (the resize target is ``newH`` x ``newW``) and are
            kept only so existing callers keep working.
        newH, newW: height and width the mask is resized to.
        invert: when True, the mask colours are inverted before resizing.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, newH, newW)`` with values
        scaled into ``[0, 1]``.
    """
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)

    # Log the size the mask is really resized to. The previous message printed
    # h0/w0 rounded down to a multiple of 64, a size that was never applied.
    print(f"New mask size ({newW}, {newH})")

    # Pillow >= 9.1 exposes the filters on Image.Resampling; older releases
    # only provide the module-level Image.LANCZOS constant.
    lanczos = getattr(Image, "Resampling", Image).LANCZOS
    image = image.resize((newW, newH), resample=lanczos)

    # uint8 HWC -> float32 in [0, 1], then add a batch axis and move channels first.
    pixels = np.array(image).astype(np.float32) / 255.0
    pixels = pixels[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(pixels)


# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
    buffered = BytesIO()