Initial support for Controlnet

This commit is contained in:
cmdr2 2023-08-01 15:39:15 +05:30
parent 05ed110519
commit ee6db85768
7 changed files with 153 additions and 9 deletions

View File

@ -9,6 +9,7 @@ from easydiffusion.types import ModelsData
from easydiffusion.utils import log
from sdkit import Context
from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db
from sdkit.models.model_loader.controlnet_filters import filters as cn_filters
from sdkit.utils import hash_file_quick
KNOWN_MODEL_TYPES = [
@ -19,6 +20,8 @@ KNOWN_MODEL_TYPES = [
"realesrgan",
"lora",
"codeformer",
"embeddings",
"controlnet",
]
MODEL_EXTENSIONS = {
"stable-diffusion": [".ckpt", ".safetensors"],
@ -29,6 +32,7 @@ MODEL_EXTENSIONS = {
"lora": [".ckpt", ".safetensors"],
"codeformer": [".pth"],
"embeddings": [".pt", ".bin", ".safetensors"],
"controlnet": [".pth", ".safetensors"],
}
DEFAULT_MODELS = {
"stable-diffusion": [
@ -177,7 +181,8 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models
def resolve_model_paths(models_data: ModelsData):
model_paths = models_data.model_paths
for model_type in model_paths:
if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths
skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"]
if model_type in skip_models: # doesn't use model paths
continue
if model_type == "codeformer":
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
@ -291,6 +296,7 @@ def getModels(scan_for_malicious: bool = True):
"lora": [],
"codeformer": ["codeformer"],
"embeddings": [],
"controlnet": [],
},
}
@ -350,6 +356,7 @@ def getModels(scan_for_malicious: bool = True):
listModels(model_type="gfpgan")
listModels(model_type="lora")
listModels(model_type="embeddings")
listModels(model_type="controlnet")
if scan_for_malicious and models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")

View File

@ -210,6 +210,9 @@ def generate_images_internal(
if req.init_image is not None and not context.test_diffusers:
req.sampler_name = "ddim"
if req.control_image and task_data.control_filter_to_apply:
req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0]
if context.test_diffusers:
pipe = context.models["stable-diffusion"]["default"]
if hasattr(pipe.unet, "_allocate_trt_buffers"):

View File

@ -75,6 +75,7 @@ class TaskData(BaseModel):
use_controlnet_model: Union[str, List[str]] = None
filters: List[str] = []
filter_params: Dict[str, Dict[str, Any]] = {}
control_filter_to_apply: Union[str, List[str]] = None
show_only_filtered_image: bool = False
block_nsfw: bool = False
@ -135,6 +136,7 @@ class GenerateImageResponse:
def json(self):
del self.render_request.init_image
del self.render_request.init_image_mask
del self.render_request.control_image
task_data = self.task_data.dict()
task_data.update(self.output_format.dict())
@ -212,6 +214,9 @@ def convert_legacy_render_req_to_new(old_req: dict):
model_paths["latent_upscaler"] = (
model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None
)
if "control_filter_to_apply" in old_req:
filter_model = old_req["control_filter_to_apply"]
model_paths[filter_model] = filter_model
if old_req.get("block_nsfw"):
model_paths["nsfw_checker"] = "nsfw_checker"

View File

@ -83,8 +83,8 @@
<label for="init_image">Initial Image (img2img) <small>(optional)</small> </label>
<div id="init_image_preview_container" class="image_preview_container">
<div id="init_image_wrapper">
<img id="init_image_preview" src="" crossorigin="anonymous" />
<div id="init_image_wrapper" class="preview_image_wrapper">
<img id="init_image_preview" class="image_preview" src="" crossorigin="anonymous" />
<span id="init_image_size_box" class="img_bottom_label"></span>
<button class="init_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
</div>
@ -151,7 +151,6 @@
<td><label for="convert_to_tensorrt">Enable TensorRT:</label></td>
<td class="diffusers-restart-needed">
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>
<label><small>Takes upto 20 mins the first time</small></label>
</td>
</tr>
@ -162,6 +161,58 @@
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Clip-Skip" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about Clip Skip</span></i></a>
</td>
</tr>
<tr id="controlnet_model_container" class="pl-5"><td><label for="controlnet_model">ControlNet Image:</label></td><td>
<div id="control_image_wrapper" class="preview_image_wrapper">
<img id="control_image_preview" class="image_preview" src="" crossorigin="anonymous" />
<span id="control_image_size_box" class="img_bottom_label"></span>
<button class="control_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
</div>
<input id="control_image" name="control_image" type="file" />
<a href="https://github.com/easydiffusion/easydiffusion/wiki/ControlNet" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about ControlNets</span></i></a>
<div id="controlnet_config" class="displayNone">
<label><small>Filter to apply:</small></label>
<select id="control_image_filter">
<option value="">None</option>
<optgroup label="Pose">
<option value="openpose">OpenPose (*)</option>
<option value="openpose_face">OpenPose face</option>
<option value="openpose_faceonly">OpenPose face-only</option>
<option value="openpose_hand">OpenPose hand</option>
<option value="openpose_full">OpenPose full</option>
</optgroup>
<optgroup label="Outline">
<option value="canny">Canny (*)</option>
<option value="mlsd">Straight lines</option>
<option value="scribble_hed">Scribble hed (*)</option>
<option value="scribble_hedsafe">Scribble hedsafe</option>
<option value="scribble_pidinet">Scribble pidinet</option>
<option value="scribble_pidsafe">Scribble pidsafe</option>
<option value="softedge_hed">Softedge hed</option>
<option value="softedge_hedsafe">Softedge hedsafe</option>
<option value="softedge_pidinet">Softedge pidinet</option>
<option value="softedge_pidsafe">Softedge pidsafe</option>
</optgroup>
<optgroup label="Depth">
<option value="normal_bae">Normal bae (*)</option>
<option value="depth_midas">Depth midas</option>
<option value="depth_zoe">Depth zoe</option>
<option value="depth_leres">Depth leres</option>
<option value="depth_leres++">Depth leres++</option>
</optgroup>
<optgroup label="Line art">
<option value="lineart_coarse">Lineart coarse</option>
<option value="lineart_realistic">Lineart realistic</option>
<option value="lineart_anime">Lineart anime</option>
</optgroup>
<optgroup label="Misc">
<option value="shuffle">Shuffle</option>
<option value="segment">Segment</option>
</optgroup>
</select>
<br/>
<label for="controlnet_model"><small>Model:</small></label> <input id="controlnet_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
</div>
</td></tr>
<tr class="pl-5"><td><label for="vae_model">Custom VAE:</label></td><td>
<input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<a href="https://github.com/easydiffusion/easydiffusion/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
@ -239,7 +290,7 @@
</select>
<label for="height"><small>(height)</small></label>
<div id="recent-resolutions-container">
<span id="recent-resolutions-button" class="clickable"><i class="fa-solid fa-sliders"><span class="simple-tooltip top-left"> Recent sizes </span></i></span>
<span id="recent-resolutions-button" class="clickable"><i class="fa-solid fa-sliders"><span class="simple-tooltip top-left"> Advanced sizes </span></i></span>
<div id="recent-resolutions-popup" class="displayNone">
<small>Custom size:</small><br>
<input id="custom-width" name="custom-width" type="number" min="128" value="512" onkeypress="preventNonNumericalInput(event)">

View File

@ -794,7 +794,7 @@ div.img-preview img {
margin-bottom: 8px;
}
#init_image_preview_container:not(.has-image) #init_image_wrapper,
#init_image_preview_container:not(.has-image) .preview_image_wrapper,
#init_image_preview_container:not(.has-image) #inpaint_button_container {
display: none;
}
@ -831,14 +831,14 @@ div.img-preview img {
gap: 8px;
}
#init_image_wrapper {
.preview_image_wrapper {
grid-row: span 3;
position: relative;
width: fit-content;
max-height: 150px;
}
#init_image_preview {
.image_preview {
max-height: 150px;
height: 100%;
width: 100%;
@ -1817,3 +1817,13 @@ div#enlarge-buttons {
.imgContainer .spinnerStatus {
font-size: 10pt;
}
#controlnet_model_container small {
color: var(--text-color)
}
#control_image {
width: 130pt;
}
#controlnet_model {
width: 77%;
}

View File

@ -93,6 +93,11 @@ let initImagePreview = document.querySelector("#init_image_preview")
let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview")
let controlImageSelector = document.querySelector("#control_image")
let controlImagePreview = document.querySelector("#control_image_preview")
let controlImageClearBtn = document.querySelector(".control_image_clear")
let controlImageContainer = document.querySelector("#control_image_wrapper")
let controlImageFilterField = document.querySelector("#control_image_filter")
let applyColorCorrectionField = document.querySelector("#apply_color_correction")
let strictMaskBorderField = document.querySelector("#strict_mask_border")
let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting")
@ -114,6 +119,7 @@ let codeformerFidelityField = document.querySelector("#codeformer_fidelity")
let stableDiffusionModelField = new ModelDropdown(document.querySelector("#stable_diffusion_model"), "stable-diffusion")
let clipSkipField = document.querySelector("#clip_skip")
let tilingField = document.querySelector("#tiling")
let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None")
let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None")
let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None")
let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider")
@ -1447,6 +1453,13 @@ function getCurrentUserRequest() {
// TRT is installed
newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked
}
if (controlnetModelField.value !== "" && IMAGE_REGEX.test(controlImagePreview.src)) {
newTask.reqBody.use_controlnet_model = controlnetModelField.value
newTask.reqBody.control_image = controlImagePreview.src
if (controlImageFilterField.value !== "") {
newTask.reqBody.control_filter_to_apply = controlImageFilterField.value
}
}
return newTask
}
@ -1853,6 +1866,20 @@ function onFixFaceModelChange() {
gfpganModelField.addEventListener("change", onFixFaceModelChange)
onFixFaceModelChange()
function onControlnetModelChange() {
let configBox = document.querySelector("#controlnet_config")
if (IMAGE_REGEX.test(controlImagePreview.src)) {
configBox.classList.remove("displayNone")
controlImageContainer.classList.remove("displayNone")
} else {
configBox.classList.add("displayNone")
controlImageContainer.classList.add("displayNone")
}
}
controlImagePreview.addEventListener("load", onControlnetModelChange)
controlImagePreview.addEventListener("unload", onControlnetModelChange)
onControlnetModelChange()
upscaleModelField.disabled = !useUpscalingField.checked
upscaleAmountField.disabled = !useUpscalingField.checked
useUpscalingField.addEventListener("change", function(e) {
@ -2143,6 +2170,44 @@ promptsFromFileBtn.addEventListener("click", function() {
promptsFromFileSelector.click()
})
function loadControlnetImageFromFile() {
if (controlImageSelector.files.length === 0) {
return
}
let reader = new FileReader()
let file = controlImageSelector.files[0]
reader.addEventListener("load", function(event) {
controlImagePreview.src = reader.result
})
if (file) {
reader.readAsDataURL(file)
}
}
controlImageSelector.addEventListener("change", loadControlnetImageFromFile)
function controlImageLoad() {
let w = controlImagePreview.naturalWidth
let h = controlImagePreview.naturalHeight
addImageSizeOption(w)
addImageSizeOption(h)
widthField.value = w
heightField.value = h
widthField.dispatchEvent(new Event("change"))
heightField.dispatchEvent(new Event("change"))
}
controlImagePreview.addEventListener("load", controlImageLoad)
function controlImageUnload() {
controlImageSelector.value = null
controlImagePreview.src = ""
controlImagePreview.dispatchEvent(new Event("unload"))
}
controlImageClearBtn.addEventListener("click", controlImageUnload)
promptsFromFileSelector.addEventListener("change", async function() {
if (promptsFromFileSelector.files.length === 0) {
return

View File

@ -667,4 +667,7 @@ async function getModels(scanForMalicious = true) {
}
// reload models button
document.querySelector("#reload-models").addEventListener("click", () => getModels())
document.querySelector("#reload-models").addEventListener("click", (e) => {
e.stopPropagation()
getModels()
})