From ee6db857681a7a86befbe6e24cea04f095c3bd3d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 1 Aug 2023 15:39:15 +0530 Subject: [PATCH] Initial support for Controlnet --- ui/easydiffusion/model_manager.py | 9 +++- ui/easydiffusion/tasks/render_images.py | 3 ++ ui/easydiffusion/types.py | 5 ++ ui/index.html | 59 ++++++++++++++++++++-- ui/media/css/main.css | 16 ++++-- ui/media/js/main.js | 65 +++++++++++++++++++++++++ ui/media/js/searchable-models.js | 5 +- 7 files changed, 153 insertions(+), 9 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 1ee5ce9d..2e7fef67 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -9,6 +9,7 @@ from easydiffusion.types import ModelsData from easydiffusion.utils import log from sdkit import Context from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db +from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.utils import hash_file_quick KNOWN_MODEL_TYPES = [ @@ -19,6 +20,8 @@ KNOWN_MODEL_TYPES = [ "realesrgan", "lora", "codeformer", + "embeddings", + "controlnet", ] MODEL_EXTENSIONS = { "stable-diffusion": [".ckpt", ".safetensors"], @@ -29,6 +32,7 @@ MODEL_EXTENSIONS = { "lora": [".ckpt", ".safetensors"], "codeformer": [".pth"], "embeddings": [".pt", ".bin", ".safetensors"], + "controlnet": [".pth", ".safetensors"], } DEFAULT_MODELS = { "stable-diffusion": [ @@ -177,7 +181,8 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models def resolve_model_paths(models_data: ModelsData): model_paths = models_data.model_paths for model_type in model_paths: - if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths + skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"] + if model_type in skip_models: # doesn't use model paths continue if model_type == "codeformer": download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") @@ -291,6 +296,7 @@ def getModels(scan_for_malicious: bool = True): "lora": [], "codeformer": ["codeformer"], "embeddings": [], + "controlnet": [], }, } @@ -350,6 +356,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="gfpgan") listModels(model_type="lora") listModels(model_type="embeddings") + listModels(model_type="controlnet") if scan_for_malicious and models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index f14d478d..8df208b6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -210,6 +210,9 @@ def generate_images_internal( if req.init_image is not None and not context.test_diffusers: req.sampler_name = "ddim" + if req.control_image and task_data.control_filter_to_apply: + req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0] + if context.test_diffusers: pipe = context.models["stable-diffusion"]["default"] if hasattr(pipe.unet, "_allocate_trt_buffers"): diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 894867b8..181a9505 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -75,6 +75,7 @@ class TaskData(BaseModel): use_controlnet_model: Union[str, List[str]] = None filters: List[str] = [] filter_params: Dict[str, Dict[str, Any]] = {} + control_filter_to_apply: Union[str, List[str]] = None show_only_filtered_image: bool = False block_nsfw: bool = False @@ -135,6 +136,7 @@ class GenerateImageResponse: def json(self): del self.render_request.init_image del self.render_request.init_image_mask + del self.render_request.control_image task_data = self.task_data.dict() task_data.update(self.output_format.dict()) @@ -212,6 +214,9 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["latent_upscaler"] = ( model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None ) + if "control_filter_to_apply" in old_req: + filter_model = old_req["control_filter_to_apply"] + model_paths[filter_model] = filter_model if old_req.get("block_nsfw"): model_paths["nsfw_checker"] = "nsfw_checker" diff --git a/ui/index.html b/ui/index.html index d2a6194c..68616ed4 100644 --- a/ui/index.html +++ b/ui/index.html @@ -83,8 +83,8 @@
-
- +
+
@@ -151,7 +151,6 @@ - Click to learn more about TensorRT @@ -162,6 +161,58 @@ Click to learn more about Clip Skip + +
+ + + +
+ + Click to learn more about ControlNets +
+ + +
+ +
+ Click to learn more about VAEs @@ -239,7 +290,7 @@
- Recent sizes + Advanced sizes
Custom size:
diff --git a/ui/media/css/main.css b/ui/media/css/main.css index 5e1cee43..3e07448e 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -794,7 +794,7 @@ div.img-preview img { margin-bottom: 8px; } -#init_image_preview_container:not(.has-image) #init_image_wrapper, +#init_image_preview_container:not(.has-image) .preview_image_wrapper, #init_image_preview_container:not(.has-image) #inpaint_button_container { display: none; } @@ -831,14 +831,14 @@ div.img-preview img { gap: 8px; } -#init_image_wrapper { +.preview_image_wrapper { grid-row: span 3; position: relative; width: fit-content; max-height: 150px; } -#init_image_preview { +.image_preview { max-height: 150px; height: 100%; width: 100%; @@ -1817,3 +1817,13 @@ div#enlarge-buttons { .imgContainer .spinnerStatus { font-size: 10pt; } + +#controlnet_model_container small { + color: var(--text-color) +} +#control_image { + width: 130pt; +} +#controlnet_model { + width: 77%; +} \ No newline at end of file diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 27724ee1..97ce807a 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -93,6 +93,11 @@ let initImagePreview = document.querySelector("#init_image_preview") let initImageSizeBox = document.querySelector("#init_image_size_box") let maskImageSelector = document.querySelector("#mask") let maskImagePreview = document.querySelector("#mask_preview") +let controlImageSelector = document.querySelector("#control_image") +let controlImagePreview = document.querySelector("#control_image_preview") +let controlImageClearBtn = document.querySelector(".control_image_clear") +let controlImageContainer = document.querySelector("#control_image_wrapper") +let controlImageFilterField = document.querySelector("#control_image_filter") let applyColorCorrectionField = document.querySelector("#apply_color_correction") let strictMaskBorderField = document.querySelector("#strict_mask_border") let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting") @@ -114,6 +119,7 @@ let codeformerFidelityField = document.querySelector("#codeformer_fidelity") let stableDiffusionModelField = new ModelDropdown(document.querySelector("#stable_diffusion_model"), "stable-diffusion") let clipSkipField = document.querySelector("#clip_skip") let tilingField = document.querySelector("#tiling") +let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None") let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None") let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None") let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider") @@ -1447,6 +1453,13 @@ function getCurrentUserRequest() { // TRT is installed newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked } + if (controlnetModelField.value !== "" && IMAGE_REGEX.test(controlImagePreview.src)) { + newTask.reqBody.use_controlnet_model = controlnetModelField.value + newTask.reqBody.control_image = controlImagePreview.src + if (controlImageFilterField.value !== "") { + newTask.reqBody.control_filter_to_apply = controlImageFilterField.value + } + } return newTask } @@ -1853,6 +1866,20 @@ function onFixFaceModelChange() { gfpganModelField.addEventListener("change", onFixFaceModelChange) onFixFaceModelChange() +function onControlnetModelChange() { + let configBox = document.querySelector("#controlnet_config") + if (IMAGE_REGEX.test(controlImagePreview.src)) { + configBox.classList.remove("displayNone") + controlImageContainer.classList.remove("displayNone") + } else { + configBox.classList.add("displayNone") + controlImageContainer.classList.add("displayNone") + } +} +controlImagePreview.addEventListener("load", onControlnetModelChange) +controlImagePreview.addEventListener("unload", onControlnetModelChange) +onControlnetModelChange() + upscaleModelField.disabled = !useUpscalingField.checked upscaleAmountField.disabled = !useUpscalingField.checked useUpscalingField.addEventListener("change", function(e) { @@ -2143,6 +2170,44 @@ promptsFromFileBtn.addEventListener("click", function() { promptsFromFileSelector.click() }) +function loadControlnetImageFromFile() { + if (controlImageSelector.files.length === 0) { + return + } + + let reader = new FileReader() + let file = controlImageSelector.files[0] + + reader.addEventListener("load", function(event) { + controlImagePreview.src = reader.result + }) + + if (file) { + reader.readAsDataURL(file) + } +} +controlImageSelector.addEventListener("change", loadControlnetImageFromFile) + +function controlImageLoad() { + let w = controlImagePreview.naturalWidth + let h = controlImagePreview.naturalHeight + addImageSizeOption(w) + addImageSizeOption(h) + + widthField.value = w + heightField.value = h + widthField.dispatchEvent(new Event("change")) + heightField.dispatchEvent(new Event("change")) +} +controlImagePreview.addEventListener("load", controlImageLoad) + +function controlImageUnload() { + controlImageSelector.value = null + controlImagePreview.src = "" + controlImagePreview.dispatchEvent(new Event("unload")) +} +controlImageClearBtn.addEventListener("click", controlImageUnload) + promptsFromFileSelector.addEventListener("change", async function() { if (promptsFromFileSelector.files.length === 0) { return diff --git a/ui/media/js/searchable-models.js b/ui/media/js/searchable-models.js index 85b9bbe9..174faf77 100644 --- a/ui/media/js/searchable-models.js +++ b/ui/media/js/searchable-models.js @@ -667,4 +667,7 @@ async function getModels(scanForMalicious = true) { } // reload models button -document.querySelector("#reload-models").addEventListener("click", () => getModels()) +document.querySelector("#reload-models").addEventListener("click", (e) => { + e.stopPropagation() + getModels() +})