diff --git a/CHANGES.md b/CHANGES.md index 1ac36410..141c0773 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -22,6 +22,7 @@ Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed. ### Detailed changelog +* 2.5.48 - 1 Aug 2023 - (beta-only) Full support for ControlNets. You can select a control image to guide the AI. You can pick a filter to pre-process the image, and one of the known (or custom) controlnet models. Supports `OpenPose`, `Canny`, `Straight Lines`, `Depth`, `Line Art`, `Scribble`, `Soft Edge`, `Shuffle` and `Segment`. * 2.5.47 - 30 Jul 2023 - An option to use `Strict Mask Border` while inpainting, to avoid touching areas outside the mask. But this might show a slight outline of the mask, which you will have to touch up separately. * 2.5.47 - 29 Jul 2023 - (beta-only) Fix long prompts with SDXL. * 2.5.47 - 29 Jul 2023 - (beta-only) Fix red dots in some SDXL images. diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 8a739a7e..bc043c7c 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -18,7 +18,7 @@ os_name = platform.system() modules_to_check = { "torch": ("1.11.0", "1.13.1", "2.0.0"), "torchvision": ("0.12.0", "0.14.1", "0.15.1"), - "sdkit": "1.0.153", + "sdkit": "1.0.156", "stable-diffusion-sdkit": "2.1.4", "rich": "12.6.0", "uvicorn": "0.19.0", diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 1ee5ce9d..63f79859 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -9,6 +9,7 @@ from easydiffusion.types import ModelsData from easydiffusion.utils import log from sdkit import Context from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db +from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.utils import hash_file_quick KNOWN_MODEL_TYPES = [ @@ -19,6 +20,8 @@ KNOWN_MODEL_TYPES = [ "realesrgan", "lora", "codeformer", + "embeddings", + "controlnet", ] MODEL_EXTENSIONS = { "stable-diffusion": [".ckpt", ".safetensors"], @@ -29,6 +32,7 @@ MODEL_EXTENSIONS = { "lora": [".ckpt", ".safetensors"], "codeformer": [".pth"], "embeddings": [".pt", ".bin", ".safetensors"], + "controlnet": [".pth", ".safetensors"], } DEFAULT_MODELS = { "stable-diffusion": [ @@ -177,10 +181,17 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models def resolve_model_paths(models_data: ModelsData): model_paths = models_data.model_paths for model_type in model_paths: - if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths + skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"] + if model_type in skip_models: # doesn't use model paths continue if model_type == "codeformer": download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") + elif model_type == "controlnet": + model_id = model_paths[model_type] + model_info = get_model_info_from_db(model_type=model_type, model_id=model_id) + if model_info: + filename = model_info.get("url", "").split("/")[-1] + download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False) model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type) @@ -204,17 +215,17 @@ def download_default_models_if_necessary(): print(model_type, "model(s) found.") -def download_if_necessary(model_type: str, file_name: str, model_id: str): +def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): model_path = os.path.join(app.MODELS_DIR, model_type, file_name) expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] - other_models_exist = any_model_exists(model_type) + other_models_exist = any_model_exists(model_type) and skip_if_others_exist known_model_exists = os.path.exists(model_path) known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash if known_model_is_corrupt or (not other_models_exist and not known_model_exists): print("> download", model_type, model_id) - download_model(model_type, model_id, download_base_dir=app.MODELS_DIR) + download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False) def migrate_legacy_model_location(): @@ -285,12 +296,26 @@ def is_malicious_model(file_path): def getModels(scan_for_malicious: bool = True): models = { "options": { - "stable-diffusion": ["sd-v1-4"], + "stable-diffusion": [{"sd-v1-4": "SD 1.4"}], "vae": [], "hypernetwork": [], "lora": [], - "codeformer": ["codeformer"], + "codeformer": [{"codeformer": "CodeFormer"}], "embeddings": [], + "controlnet": [ + {"control_v11p_sd15_canny": "Canny (*)"}, + {"control_v11p_sd15_openpose": "OpenPose (*)"}, + {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, + {"control_v11f1p_sd15_depth": "Depth (*)"}, + {"control_v11p_sd15_scribble": "Scribble"}, + {"control_v11p_sd15_softedge": "Soft Edge"}, + {"control_v11p_sd15_inpaint": "Inpaint"}, + {"control_v11p_sd15_lineart": "Line Art"}, + {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, + {"control_v11p_sd15_mlsd": "Straight Lines"}, + {"control_v11p_sd15_seg": "Segment"}, + {"control_v11e_sd15_shuffle": "Shuffle"}, + ], }, } @@ -299,9 +324,9 @@ def getModels(scan_for_malicious: bool = True): class MaliciousModelException(Exception): "Raised when picklescan reports a problem with a model" - def scan_directory(directory, suffixes, directoriesFirst: bool = True): + def scan_directory(directory, suffixes, directoriesFirst: bool = True, default_entries=[]): + tree = list(default_entries) nonlocal models_scanned - tree = [] for entry in sorted( os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()), @@ -320,7 +345,14 @@ def getModels(scan_for_malicious: bool = True): raise MaliciousModelException(entry.path) if scan_for_malicious: known_models[entry.path] = mtime - tree.append(entry.name[: -len(matching_suffix)]) + model_id = entry.name[: -len(matching_suffix)] + model_exists = False + for m in tree: # allows default "named" models, like CodeFormer and known ControlNet models + if (isinstance(m, str) and model_id == m) or (isinstance(m, dict) and model_id in m): + model_exists = True + break + if not model_exists: + tree.append(model_id) elif entry.is_dir(): scan = scan_directory(entry.path, suffixes, directoriesFirst=False) @@ -337,7 +369,8 @@ def getModels(scan_for_malicious: bool = True): os.makedirs(models_dir) try: - models["options"][model_type] = scan_directory(models_dir, model_extensions) + default_tree = models["options"].get(model_type, []) + models["options"][model_type] = scan_directory(models_dir, model_extensions, default_entries=default_tree) except MaliciousModelException as e: models["scan-error"] = str(e) @@ -350,6 +383,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="gfpgan") listModels(model_type="lora") listModels(model_type="embeddings") + listModels(model_type="controlnet") if scan_for_malicious and models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index bbc36aa5..8df208b6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -63,7 +63,7 @@ class RenderTask(Task): if ( runtime.set_vram_optimizations(context) or self.has_param_changed(context, "clip_skip") - or self.has_param_changed(context, "convert_to_tensorrt") + or self.trt_needs_reload(context) ): models_to_force_reload.append("stable-diffusion") @@ -92,6 +92,17 @@ class RenderTask(Task): new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) return model["params"].get(param_name) != new_val + def trt_needs_reload(self, context): + if not self.has_param_changed(context, "convert_to_tensorrt"): + return False + + model = context.models["stable-diffusion"] + pipe = model["default"] + if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded + return False + + return True + def make_images( context, @@ -148,6 +159,7 @@ def make_images_internal( context, req, task_data, + models_data, data_queue, task_temp_images, step_callback, @@ -174,6 +186,7 @@ def generate_images_internal( context, req: GenerateImageRequest, task_data: TaskData, + models_data: ModelsData, data_queue: queue.Queue, task_temp_images: list, step_callback, @@ -197,6 +210,18 @@ def generate_images_internal( if req.init_image is not None and not context.test_diffusers: req.sampler_name = "ddim" + if req.control_image and task_data.control_filter_to_apply: + req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0] + + if context.test_diffusers: + pipe = context.models["stable-diffusion"]["default"] + if hasattr(pipe.unet, "_allocate_trt_buffers"): + convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False) + pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward + # pipe.vae.decoder.forward = ( + # pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward + # ) + images = generate_images(context, callback=callback, **req.dict()) user_stopped = False except UserInitiatedStop: diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 894867b8..181a9505 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -75,6 +75,7 @@ class TaskData(BaseModel): use_controlnet_model: Union[str, List[str]] = None filters: List[str] = [] filter_params: Dict[str, Dict[str, Any]] = {} + control_filter_to_apply: Union[str, List[str]] = None show_only_filtered_image: bool = False block_nsfw: bool = False @@ -135,6 +136,7 @@ class GenerateImageResponse: def json(self): del self.render_request.init_image del self.render_request.init_image_mask + del self.render_request.control_image task_data = self.task_data.dict() task_data.update(self.output_format.dict()) @@ -212,6 +214,9 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["latent_upscaler"] = ( model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None ) + if "control_filter_to_apply" in old_req: + filter_model = old_req["control_filter_to_apply"] + model_paths[filter_model] = filter_model if old_req.get("block_nsfw"): model_paths["nsfw_checker"] = "nsfw_checker" diff --git a/ui/index.html b/ui/index.html index de4273aa..508b4637 100644 --- a/ui/index.html +++ b/ui/index.html @@ -34,7 +34,7 @@
Click to learn more about custom models | ||
+ | - Click to learn more about TensorRT - + | |
+
+
+
+ Click to learn more about ControlNets
+
+
+
+
+ + + + + | ||
Click to learn more about VAEs
@@ -241,7 +294,7 @@
-
+
Custom size: @@ -737,7 +790,7 @@ async function init() { ping: onPing } }) - splashScreen() + // splashScreen() // load models again, but scan for malicious this time await getModels(true) diff --git a/ui/media/css/main.css b/ui/media/css/main.css index acdae9a9..b6dc85c4 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -794,7 +794,7 @@ div.img-preview img { margin-bottom: 8px; } -#init_image_preview_container:not(.has-image) #init_image_wrapper, +#init_image_preview_container:not(.has-image) .preview_image_wrapper, #init_image_preview_container:not(.has-image) #inpaint_button_container { display: none; } @@ -831,14 +831,14 @@ div.img-preview img { gap: 8px; } -#init_image_wrapper { +.preview_image_wrapper { grid-row: span 3; position: relative; width: fit-content; max-height: 150px; } -#init_image_preview { +.image_preview { max-height: 150px; height: 100%; width: 100%; @@ -1843,3 +1843,13 @@ div#enlarge-buttons { .imgContainer .spinnerStatus { font-size: 10pt; } + +#controlnet_model_container small { + color: var(--text-color) +} +#control_image { + width: 130pt; +} +#controlnet_model { + width: 77%; +} \ No newline at end of file diff --git a/ui/media/js/main.js b/ui/media/js/main.js index effd9660..6a5c92fc 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -93,6 +93,11 @@ let initImagePreview = document.querySelector("#init_image_preview") let initImageSizeBox = document.querySelector("#init_image_size_box") let maskImageSelector = document.querySelector("#mask") let maskImagePreview = document.querySelector("#mask_preview") +let controlImageSelector = document.querySelector("#control_image") +let controlImagePreview = document.querySelector("#control_image_preview") +let controlImageClearBtn = document.querySelector(".control_image_clear") +let controlImageContainer = document.querySelector("#control_image_wrapper") +let controlImageFilterField = document.querySelector("#control_image_filter") let applyColorCorrectionField = document.querySelector("#apply_color_correction") let strictMaskBorderField = document.querySelector("#strict_mask_border") let colorCorrectionSetting = document.querySelector("#apply_color_correction_setting") @@ -114,6 +119,7 @@ let codeformerFidelityField = document.querySelector("#codeformer_fidelity") let stableDiffusionModelField = new ModelDropdown(document.querySelector("#stable_diffusion_model"), "stable-diffusion") let clipSkipField = document.querySelector("#clip_skip") let tilingField = document.querySelector("#tiling") +let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None", false) let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None") let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None") let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider") @@ -1378,9 +1384,19 @@ function createTask(task) { function getCurrentUserRequest() { const numOutputsTotal = parseInt(numOutputsTotalField.value) - const numOutputsParallel = parseInt(numOutputsParallelField.value) + let numOutputsParallel = parseInt(numOutputsParallelField.value) const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value) + if ( + testDiffusers.checked && + document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall" && + document.querySelector("#convert_to_tensorrt").checked + ) { + // TRT enabled + + numOutputsParallel = 1 // force 1 parallel + } + const newTask = { batchesDone: 0, numOutputsTotal: numOutputsTotal, @@ -1469,6 +1485,13 @@ function getCurrentUserRequest() { // TRT is installed newTask.reqBody.convert_to_tensorrt = document.querySelector("#convert_to_tensorrt").checked } + if (controlnetModelField.value !== "" && IMAGE_REGEX.test(controlImagePreview.src)) { + newTask.reqBody.use_controlnet_model = controlnetModelField.value + newTask.reqBody.control_image = controlImagePreview.src + if (controlImageFilterField.value !== "") { + newTask.reqBody.control_filter_to_apply = controlImageFilterField.value + } + } return newTask } @@ -1875,6 +1898,51 @@ function onFixFaceModelChange() { gfpganModelField.addEventListener("change", onFixFaceModelChange) onFixFaceModelChange() +function onControlnetModelChange() { + let configBox = document.querySelector("#controlnet_config") + if (IMAGE_REGEX.test(controlImagePreview.src)) { + configBox.classList.remove("displayNone") + controlImageContainer.classList.remove("displayNone") + } else { + configBox.classList.add("displayNone") + controlImageContainer.classList.add("displayNone") + } +} +controlImagePreview.addEventListener("load", onControlnetModelChange) +controlImagePreview.addEventListener("unload", onControlnetModelChange) +onControlnetModelChange() + +function onControlImageFilterChange() { + let filterId = controlImageFilterField.value + if (filterId.includes("openpose")) { + controlnetModelField.value = "control_v11p_sd15_openpose" + } else if (filterId === "canny") { + controlnetModelField.value = "control_v11p_sd15_canny" + } else if (filterId === "mlsd") { + controlnetModelField.value = "control_v11p_sd15_mlsd" + } else if (filterId === "mlsd") { + controlnetModelField.value = "control_v11p_sd15_mlsd" + } else if (filterId.includes("scribble")) { + controlnetModelField.value = "control_v11p_sd15_scribble" + } else if (filterId.includes("softedge")) { + controlnetModelField.value = "control_v11p_sd15_softedge" + } else if (filterId === "normal_bae") { + controlnetModelField.value = "control_v11p_sd15_normalbae" + } else if (filterId.includes("depth")) { + controlnetModelField.value = "control_v11f1p_sd15_depth" + } else if (filterId === "lineart_anime") { + controlnetModelField.value = "control_v11p_sd15s2_lineart_anime" + } else if (filterId.includes("lineart")) { + controlnetModelField.value = "control_v11p_sd15_lineart" + } else if (filterId === "shuffle") { + controlnetModelField.value = "control_v11e_sd15_shuffle" + } else if (filterId === "segment") { + controlnetModelField.value = "control_v11p_sd15_seg" + } +} +controlImageFilterField.addEventListener("change", onControlImageFilterChange) +onControlImageFilterChange() + upscaleModelField.disabled = !useUpscalingField.checked upscaleAmountField.disabled = !useUpscalingField.checked useUpscalingField.addEventListener("change", function(e) { @@ -2165,6 +2233,44 @@ promptsFromFileBtn.addEventListener("click", function() { promptsFromFileSelector.click() }) +function loadControlnetImageFromFile() { + if (controlImageSelector.files.length === 0) { + return + } + + let reader = new FileReader() + let file = controlImageSelector.files[0] + + reader.addEventListener("load", function(event) { + controlImagePreview.src = reader.result + }) + + if (file) { + reader.readAsDataURL(file) + } +} +controlImageSelector.addEventListener("change", loadControlnetImageFromFile) + +function controlImageLoad() { + let w = controlImagePreview.naturalWidth + let h = controlImagePreview.naturalHeight + addImageSizeOption(w) + addImageSizeOption(h) + + widthField.value = w + heightField.value = h + widthField.dispatchEvent(new Event("change")) + heightField.dispatchEvent(new Event("change")) +} +controlImagePreview.addEventListener("load", controlImageLoad) + +function controlImageUnload() { + controlImageSelector.value = null + controlImagePreview.src = "" + controlImagePreview.dispatchEvent(new Event("unload")) +} +controlImageClearBtn.addEventListener("click", controlImageUnload) + promptsFromFileSelector.addEventListener("change", async function() { if (promptsFromFileSelector.files.length === 0) { return @@ -2293,6 +2399,8 @@ function tunnelUpdate(event) { } } +let trtSettingsForced = false + function packagesUpdate(event) { let trtBtn = document.getElementById("toggle-tensorrt-install") let trtInstalled = "packages_installed" in event && "tensorrt" in event["packages_installed"] @@ -2307,6 +2415,22 @@ function packagesUpdate(event) { if (document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall") { document.querySelector("#enable_trt_config").classList.remove("displayNone") + + if (!trtSettingsForced) { + // settings for demo + promptField.value = "Dragons fighting with a knight, castle, war scene, fantasy, cartoon, flames, HD" + seedField.value = 3187947173 + widthField.value = 1024 + heightField.value = 768 + randomSeedField.checked = false + seedField.disabled = false + stableDiffusionModelField.value = "sd-v1-4" + + numOutputsParallelField.classList.add("displayNone") + document.querySelector("#num_outputs_parallel_label").classList.add("displayNone") + + trtSettingsForced = true + } } } diff --git a/ui/media/js/searchable-models.js b/ui/media/js/searchable-models.js index 85b9bbe9..299c60dc 100644 --- a/ui/media/js/searchable-models.js +++ b/ui/media/js/searchable-models.js @@ -552,17 +552,23 @@ class ModelDropdown { this.createModelNodeList(`${folderName || ""}/${childFolderName}`, childModels, false) ) } else { + let modelId = model + let modelName = model + if (typeof model === "object") { + modelId = Object.keys(model)[0] + modelName = model[modelId] + } const classes = ["model-file"] if (isRootFolder) { classes.push("in-root-folder") } // Remove the leading slash from the model path - const fullPath = folderName ? `${folderName.substring(1)}/${model}` : model + const fullPath = folderName ? `${folderName.substring(1)}/${modelId}` : modelId modelsMap.set( - model, + modelId, createElement("li", { "data-path": fullPath }, classes, [ createElement("i", undefined, ["fa-regular", "fa-file", "icon"]), - model, + modelName, ]) ) } @@ -643,22 +649,6 @@ async function getModels(scanForMalicious = true) { makeImageBtn.disabled = true } - /* This code should no longer be needed. Commenting out for now, will cleanup later. - const sd_model_setting_key = "stable_diffusion_model" - const vae_model_setting_key = "vae_model" - const hypernetwork_model_key = "hypernetwork_model" - - const stableDiffusionOptions = modelsOptions['stable-diffusion'] - const vaeOptions = modelsOptions['vae'] - const hypernetworkOptions = modelsOptions['hypernetwork'] - - // TODO: set default for model here too - SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0] - if (getSetting(sd_model_setting_key) == '' || SETTINGS[sd_model_setting_key].value == '') { - setSetting(sd_model_setting_key, stableDiffusionOptions[0]) - } - */ - // notify ModelDropdown objects to refresh document.dispatchEvent(new Event("refreshModels")) } catch (e) { @@ -667,4 +657,7 @@ async function getModels(scanForMalicious = true) { } // reload models button -document.querySelector("#reload-models").addEventListener("click", () => getModels()) +document.querySelector("#reload-models").addEventListener("click", (e) => { + e.stopPropagation() + getModels() +}) |