diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 324dcec9..d6a227be 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -107,12 +107,15 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None): def reload_models_if_necessary(context: Context, task_data: TaskData): + use_upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else "" + model_paths_in_req = { "stable-diffusion": task_data.use_stable_diffusion_model, "vae": task_data.use_vae_model, "hypernetwork": task_data.use_hypernetwork_model, "gfpgan": task_data.use_face_correction, - "realesrgan": task_data.use_upscale, + "realesrgan": task_data.use_upscale if "realesrgan" in use_upscale_lower else None, + "latent_upscaler": True if task_data.use_upscale == "latent_upscaler" else None, "nsfw_checker": True if task_data.block_nsfw else None, "lora": task_data.use_lora_model, } @@ -142,7 +145,7 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan") - if task_data.use_upscale: + if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan") diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index e26b4389..c60c42df 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -95,7 +95,7 @@ def make_images_internal( task_data.stream_image_progress_interval, ) gc(context) - filtered_images = filter_images(task_data, images, user_stopped) + filtered_images = filter_images(req, task_data, images, user_stopped) if task_data.save_to_disk_path is not None: save_images_to_disk(images, filtered_images, req, task_data) @@ -151,22 +151,36 @@ def generate_images_internal( return images, user_stopped -def filter_images(task_data: TaskData, images: list, user_stopped): +def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, user_stopped): if user_stopped: return images filters_to_apply = [] + filter_params = {} if task_data.block_nsfw: filters_to_apply.append("nsfw_checker") if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): filters_to_apply.append("gfpgan") - if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): - filters_to_apply.append("realesrgan") + if task_data.use_upscale: + if "realesrgan" in task_data.use_upscale.lower(): + filters_to_apply.append("realesrgan") + elif task_data.use_upscale == "latent_upscaler": + filters_to_apply.append("latent_upscaler") + + filter_params["latent_upscaler_options"] = { + "prompt": req.prompt, + "negative_prompt": req.negative_prompt, + "seed": req.seed, + "num_inference_steps": task_data.latent_upscaler_steps, + "guidance_scale": 0, + } + + filter_params["scale"] = task_data.upscale_amount if len(filters_to_apply) == 0: return images - return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount) + return apply_filters(context, filters_to_apply, images, **filter_params) def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int): diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 7a5201ab..a76f489a 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -32,8 +32,9 @@ class TaskData(BaseModel): vram_usage_level: str = "balanced" # or "low" or "medium" use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" + use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" or "latent_upscaler" upscale_amount: int = 4 # or 2 + latent_upscaler_steps: int = 10 use_stable_diffusion_model: str = "sd-v1-4" # use_stable_diffusion_config: str = "v1-inference" use_vae_model: str = None diff --git a/ui/index.html b/ui/index.html index 99087eec..5097d84a 100644 --- a/ui/index.html +++ b/ui/index.html @@ -258,14 +258,18 @@
  • with +
    + +
  • diff --git a/ui/media/css/main.css b/ui/media/css/main.css index ba513237..8f4f49fa 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -1303,6 +1303,12 @@ body.wait-pause { display:none !important; } +#latent_upscaler_settings { + padding-top: 3pt; + padding-bottom: 3pt; + padding-left: 5pt; +} + /* TOAST NOTIFICATIONS */ .toast-notification { position: fixed; diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 0ce32f2b..23ed5f46 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -86,6 +86,9 @@ let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model") let useUpscalingField = document.querySelector("#use_upscale") let upscaleModelField = document.querySelector("#upscale_model") let upscaleAmountField = document.querySelector("#upscale_amount") +let latentUpscalerSettings = document.querySelector("#latent_upscaler_settings") +let latentUpscalerStepsSlider = document.querySelector("#latent_upscaler_steps_slider") +let latentUpscalerStepsField = document.querySelector("#latent_upscaler_steps") let stableDiffusionModelField = new ModelDropdown(document.querySelector("#stable_diffusion_model"), "stable-diffusion") let clipSkipField = document.querySelector("#clip_skip") let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None") @@ -239,7 +242,7 @@ function setServerStatus(event) { break } if (SD.serverState.devices) { - document.dispatchEvent(new CustomEvent("system_info_update", { detail: SD.serverState.devices})) + document.dispatchEvent(new CustomEvent("system_info_update", { detail: SD.serverState.devices })) } } @@ -1268,6 +1271,10 @@ function getCurrentUserRequest() { if (useUpscalingField.checked) { newTask.reqBody.use_upscale = upscaleModelField.value newTask.reqBody.upscale_amount = upscaleAmountField.value + if (upscaleModelField.value === "latent_upscaler") { + newTask.reqBody.upscale_amount = "2" + newTask.reqBody.latent_upscaler_steps = latentUpscalerStepsField.value + } } if (hypernetworkModelField.value) { newTask.reqBody.use_hypernetwork_model = hypernetworkModelField.value @@ -1582,6 +1589,20 @@ useUpscalingField.addEventListener("change", function(e) { upscaleAmountField.disabled = !this.checked }) +function onUpscaleModelChange() { + let upscale4x = document.querySelector("#upscale_amount_4x") + if (upscaleModelField.value === "latent_upscaler") { + upscale4x.disabled = true + upscaleAmountField.value = "2" + latentUpscalerSettings.classList.remove("displayNone") + } else { + upscale4x.disabled = false + latentUpscalerSettings.classList.add("displayNone") + } +} +upscaleModelField.addEventListener("change", onUpscaleModelChange) +onUpscaleModelChange() + makeImageBtn.addEventListener("click", makeImage) document.onkeydown = function(e) { @@ -1591,6 +1612,27 @@ document.onkeydown = function(e) { } } +/********************* Latent Upscaler Steps **************************/ +function updateLatentUpscalerSteps() { + latentUpscalerStepsField.value = latentUpscalerStepsSlider.value + latentUpscalerStepsField.dispatchEvent(new Event("change")) +} + +function updateLatentUpscalerStepsSlider() { + if (latentUpscalerStepsField.value < 1) { + latentUpscalerStepsField.value = 1 + } else if (latentUpscalerStepsField.value > 50) { + latentUpscalerStepsField.value = 50 + } + + latentUpscalerStepsSlider.value = latentUpscalerStepsField.value + latentUpscalerStepsSlider.dispatchEvent(new Event("change")) +} + +latentUpscalerStepsSlider.addEventListener("input", updateLatentUpscalerSteps) +latentUpscalerStepsField.addEventListener("input", updateLatentUpscalerStepsSlider) +updateLatentUpscalerSteps() + /********************* Guidance **************************/ function updateGuidanceScale() { guidanceScaleField.value = guidanceScaleSlider.value / 10