From 889a070e62844bd08be822009716277f50ce5818 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 14 Jul 2025 13:20:26 +0530 Subject: [PATCH] Support custom text encoders and Flux VAEs in the UI --- scripts/check_modules.py | 4 +- ui/easydiffusion/backends/webui/__init__.py | 2 + ui/easydiffusion/backends/webui/impl.py | 74 +++++++++++---------- ui/easydiffusion/model_manager.py | 72 +++++++++++++++----- ui/easydiffusion/types.py | 2 + ui/index.html | 10 ++- ui/media/js/auto-save.js | 1 + ui/media/js/dnd.js | 40 +++++++++++ ui/media/js/main.js | 60 ++++++++++++++--- ui/media/js/multi-model-selector.js | 31 ++++++--- 10 files changed, 218 insertions(+), 78 deletions(-) diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 69877650..bda95cb1 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -124,10 +124,10 @@ def update_modules(): # if sdkit is 2.0.15.x (or lower), then diffusers should be restricted to 0.21.4 (see below for the reason) # otherwise use the current sdkit version (with the corresponding diffusers version) - expected_sdkit_version_str = "2.0.22.8" + expected_sdkit_version_str = "2.0.22.9" expected_diffusers_version_str = "0.28.2" - legacy_sdkit_version_str = "2.0.15.17" + legacy_sdkit_version_str = "2.0.15.18" legacy_diffusers_version_str = "0.21.4" sdkit_version_str = version("sdkit") diff --git a/ui/easydiffusion/backends/webui/__init__.py b/ui/easydiffusion/backends/webui/__init__.py index b8cef154..54a43f4f 100644 --- a/ui/easydiffusion/backends/webui/__init__.py +++ b/ui/easydiffusion/backends/webui/__init__.py @@ -18,6 +18,7 @@ from .impl import ( ping, load_model, unload_model, + flush_model_changes, set_options, generate_images, filter_images, @@ -53,6 +54,7 @@ MODELS_TO_OVERRIDE = { "codeformer": "--codeformer-models-path", "embeddings": "--embeddings-dir", "controlnet": "--controlnet-dir", + "text-encoder": "--text-encoder-dir", } WEBUI_PATCHES = [ diff --git a/ui/easydiffusion/backends/webui/impl.py b/ui/easydiffusion/backends/webui/impl.py index afc5cd9e..7e21cf4c 100644 --- a/ui/easydiffusion/backends/webui/impl.py +++ b/ui/easydiffusion/backends/webui/impl.py @@ -27,6 +27,7 @@ webui_opts: dict = None curr_models = { "stable-diffusion": None, "vae": None, + "text-encoder": None, } @@ -96,50 +97,51 @@ def ping(timeout=1): def load_model(context, model_type, **kwargs): + from easydiffusion.app import ROOT_DIR, getConfig + + config = getConfig() + models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models")) + model_path = context.model_paths[model_type] + if model_type == "stable-diffusion": + base_dir = os.path.join(models_dir, model_type) + model_path = os.path.relpath(model_path, base_dir) + + # print(f"load model: {model_type=} {model_path=} {curr_models=}") + curr_models[model_type] = model_path + + +def unload_model(context, model_type, **kwargs): + # print(f"unload model: {model_type=} {curr_models=}") + curr_models[model_type] = None + + +def flush_model_changes(context): if webui_opts is None: print("Server not ready, can't set the model") return - if model_type == "stable-diffusion": - model_name = os.path.basename(model_path) - model_name = os.path.splitext(model_name)[0] - print(f"setting sd model: {model_name}") - if curr_models[model_type] != model_name: - try: - res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name}) - if res.status_code != 200: - raise Exception(res.text) - except Exception as e: - raise RuntimeError( - f"The engine failed to set the required options. Please check the logs in the command line window for more details." - ) + modules = [] + for model_type in ("vae", "text-encoder"): + if curr_models[model_type]: + model_paths = curr_models[model_type] + model_paths = [model_paths] if not isinstance(model_paths, list) else model_paths + modules += model_paths - curr_models[model_type] = model_name - elif model_type == "vae": - if curr_models[model_type] != model_path: - vae_model = [model_path] if model_path else [] + opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": modules} - opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model} - print("setting opts 2", opts) + print("Setting backend models", opts) - try: - res = webui_post("/sdapi/v1/options", json=opts) - if res.status_code != 200: - raise Exception(res.text) - except Exception as e: - raise RuntimeError( - f"The engine failed to set the required options. Please check the logs in the command line window for more details." - ) - - curr_models[model_type] = model_path - - -def unload_model(context, model_type, **kwargs): - if model_type == "vae": - context.model_paths[model_type] = None - load_model(context, model_type) + try: + res = webui_post("/sdapi/v1/options", json=opts) + print("got res", res.status_code) + if res.status_code != 200: + raise Exception(res.text) + except Exception as e: + raise RuntimeError( + f"The engine failed to set the required options. Please check the logs in the command line window for more details." + ) def generate_images( @@ -346,7 +348,7 @@ def refresh_models(): pass try: - for type in ("checkpoints", "vae"): + for type in ("checkpoints", "vae-and-text-encoders"): t = Thread(target=make_refresh_call, args=(type,)) t.start() except Exception as e: diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 45b38b96..0fe5a2c9 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -3,6 +3,7 @@ import shutil from glob import glob import traceback from typing import Union +from os import path from easydiffusion import app from easydiffusion.types import ModelsData @@ -22,6 +23,7 @@ KNOWN_MODEL_TYPES = [ "codeformer", "embeddings", "controlnet", + "text-encoder", ] MODEL_EXTENSIONS = { "stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"], @@ -33,6 +35,7 @@ MODEL_EXTENSIONS = { "codeformer": [".pth"], "embeddings": [".pt", ".bin", ".safetensors", ".sft"], "controlnet": [".pth", ".safetensors", ".sft"], + "text-encoder": [".safetensors", ".sft"], } DEFAULT_MODELS = { "stable-diffusion": [ @@ -59,6 +62,7 @@ ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility "realesrgan": "RealESRGAN", "lora": "Lora", "controlnet": "ControlNet", + "text-encoder": "text_encoder", } known_models = {} @@ -204,6 +208,9 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models context.model_load_errors = {} context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks + if hasattr(backend, "flush_model_changes"): + backend.flush_model_changes(context) + def resolve_model_paths(models_data: ModelsData): from easydiffusion.backend_manager import backend @@ -224,14 +231,25 @@ def resolve_model_paths(models_data: ModelsData): for model_type in model_paths: if model_type in skip_models: # doesn't use model paths continue - if model_type == "codeformer" and model_paths[model_type]: - download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") - elif model_type == "controlnet" and model_paths[model_type]: - model_id = model_paths[model_type] - model_info = get_model_info_from_db(model_type=model_type, model_id=model_id) - if model_info: - filename = model_info.get("url", "").split("/")[-1] - download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False) + + if model_type in ("vae", "codeformer", "controlnet", "text-encoder") and model_paths[model_type]: + model_ids = model_paths[model_type] + model_ids = model_ids if isinstance(model_ids, list) else [model_ids] + + new_model_paths = [] + + for model_id in model_ids: + log.info(f"Checking for {model_id=}") + model_info = get_model_info_from_db(model_type=model_type, model_id=model_id) + if model_info: + filename = model_info.get("url", "").split("/")[-1] + download_if_necessary(model_type, filename, model_id, skip_if_others_exist=False) + + new_model_paths.append(path.splitext(filename)[0]) + else: # not in the model db, probably a regular file + new_model_paths.append(model_id) + + model_paths[model_type] = new_model_paths model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type) @@ -256,17 +274,31 @@ def download_default_models_if_necessary(): def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): - model_dir = get_model_dirs(model_type)[0] - model_path = os.path.join(model_dir, file_name) + from easydiffusion.backend_manager import backend + expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] - other_models_exist = any_model_exists(model_type) and skip_if_others_exist - known_model_exists = os.path.exists(model_path) - known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash - if known_model_is_corrupt or (not other_models_exist and not known_model_exists): - print("> download", model_type, model_id) - download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False) + for model_dir in get_model_dirs(model_type): + model_path = os.path.join(model_dir, file_name) + + known_model_exists = os.path.exists(model_path) + known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash + + needs_download = known_model_is_corrupt or (not other_models_exist and not known_model_exists) + + log.info(f"{model_path=} {needs_download=}") + if known_model_exists: + log.info(f"{expected_hash=} {hash_file_quick(model_path)=}") + log.info(f"{known_model_is_corrupt=} {other_models_exist=} {known_model_exists=}") + + if not needs_download: + return + + print("> download", model_type, model_id) + download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False) + + backend.refresh_models() def migrate_legacy_model_location(): @@ -363,7 +395,7 @@ def getModels(scan_for_malicious: bool = True): models = { "options": { "stable-diffusion": [], - "vae": [], + "vae": [{"ae": "ae (Flux VAE fp16)"}], "hypernetwork": [], "lora": [], "codeformer": [{"codeformer": "CodeFormer"}], @@ -383,6 +415,11 @@ def getModels(scan_for_malicious: bool = True): # {"control_v11e_sd15_shuffle": "Shuffle"}, # {"control_v11f1e_sd15_tile": "Tile"}, ], + "text-encoder": [ + {"t5xxl_fp16": "T5 XXL fp16"}, + {"clip_l": "CLIP L"}, + {"clip_g": "CLIP G"}, + ], }, } @@ -466,6 +503,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="lora") listModels(model_type="embeddings", nameFilter=get_embedding_token) listModels(model_type="controlnet") + listModels(model_type="text-encoder") if scan_for_malicious and models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index bc0ccabf..b9ffe283 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -80,6 +80,7 @@ class RenderTaskData(TaskData): latent_upscaler_steps: int = 10 use_stable_diffusion_model: Union[str, List[str]] = "sd-v1-4" use_vae_model: Union[str, List[str]] = None + use_text_encoder_model: Union[str, List[str]] = None use_hypernetwork_model: Union[str, List[str]] = None use_lora_model: Union[str, List[str]] = None use_controlnet_model: Union[str, List[str]] = None @@ -211,6 +212,7 @@ def convert_legacy_render_req_to_new(old_req: dict): # move the model info model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model") model_paths["vae"] = old_req.get("use_vae_model") + model_paths["text-encoder"] = old_req.get("use_text_encoder_model") model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model") model_paths["lora"] = old_req.get("use_lora_model") model_paths["controlnet"] = old_req.get("use_controlnet_model") diff --git a/ui/index.html b/ui/index.html index e26b4693..b1c02ed6 100644 --- a/ui/index.html +++ b/ui/index.html @@ -302,6 +302,14 @@ Click to learn more about VAEs + + + + + +
+ +
- + diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 7d8b864e..5553e78f 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -60,6 +60,7 @@ const SETTINGS_IDS_LIST = [ "extract_lora_from_prompt", "embedding-card-size-selector", "lora_model", + "text_encoder_model", "enable_vae_tiling", "controlnet_alpha", ] diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index cc15cc35..f5f9bb08 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -394,6 +394,45 @@ const TASK_MAPPING = { return val }, }, + use_text_encoder_model: { + name: "Text Encoder model", + setUI: (use_text_encoder_model) => { + let modelPaths = [] + use_text_encoder_model = use_text_encoder_model === null ? "" : use_text_encoder_model + use_text_encoder_model = Array.isArray(use_text_encoder_model) ? use_text_encoder_model : [use_text_encoder_model] + use_text_encoder_model.forEach((m) => { + if (m.includes("models\\text-encoder\\")) { + m = m.split("models\\text-encoder\\")[1] + } else if (m.includes("models\\\\text-encoder\\\\")) { + m = m.split("models\\\\text-encoder\\\\")[1] + } else if (m.includes("models/text-encoder/")) { + m = m.split("models/text-encoder/")[1] + } + m = m.replaceAll("\\\\", "/") + m = getModelPath(m, [".safetensors", ".sft"]) + modelPaths.push(m) + }) + text_encoderModelField.modelNames = modelPaths + }, + readUI: () => { + return text_encoderModelField.modelNames + }, + parse: (val) => { + val = !val || val === "None" ? "" : val + if (typeof val === "string" && val.includes(",")) { + val = val.split(",") + val = val.map((v) => v.trim()) + val = val.map((v) => v.replaceAll("\\", "\\\\")) + val = val.map((v) => v.replaceAll('"', "")) + val = val.map((v) => v.replaceAll("'", "")) + val = val.map((v) => '"' + v + '"') + val = "[" + val + "]" + val = JSON.parse(val) + } + val = Array.isArray(val) ? val : [val] + return val + }, + }, use_hypernetwork_model: { name: "Hypernetwork model", setUI: (use_hypernetwork_model) => { @@ -620,6 +659,7 @@ const TASK_TEXT_MAPPING = { hypernetwork_strength: "Hypernetwork Strength", use_lora_model: "LoRA model", lora_alpha: "LoRA Strength", + use_text_encoder_model: "Text Encoder model", use_controlnet_model: "ControlNet model", control_filter_to_apply: "ControlNet Filter", control_alpha: "ControlNet Strength", diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 755bca2f..95cef636 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -54,6 +54,7 @@ const taskConfigSetup = { label: "Hypernetwork Strength", visible: ({ reqBody }) => !!reqBody?.use_hypernetwork_model, }, + use_text_encoder_model: { label: "Text Encoder", visible: ({ reqBody }) => !!reqBody?.use_text_encoder_model }, use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model }, lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model }, preserve_init_image_color_profile: "Preserve Color Profile", @@ -141,6 +142,7 @@ let tilingField = document.querySelector("#tiling") let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None", false) let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None") let loraModelField = new MultiModelSelector(document.querySelector("#lora_model"), "lora", "LoRA", 0.5, 0.02) +let textEncoderModelField = new MultiModelSelector(document.querySelector("#text_encoder_model"), "text-encoder", "Text Encoder", 0.5, 0.02, false) let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None") let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider") let hypernetworkStrengthField = document.querySelector("#hypernetwork_strength") @@ -1396,6 +1398,7 @@ function getCurrentUserRequest() { newTask.reqBody.hypernetwork_strength = parseFloat(hypernetworkStrengthField.value) } if (testDiffusers.checked) { + // lora let loraModelData = loraModelField.value let modelNames = loraModelData["modelNames"] let modelStrengths = loraModelData["modelWeights"] @@ -1408,6 +1411,16 @@ function getCurrentUserRequest() { newTask.reqBody.lora_alpha = modelStrengths } + // text encoder + let textEncoderModelNames = textEncoderModelField.modelNames + + if (textEncoderModelNames.length > 0) { + textEncoderModelNames = textEncoderModelNames.length == 1 ? textEncoderModelNames[0] : textEncoderModelNames + + newTask.reqBody.use_text_encoder_model = textEncoderModelNames + } + + // vae tiling if (tilingField.value !== "none") { newTask.reqBody.tiling = tilingField.value } @@ -1891,8 +1904,31 @@ document.addEventListener("refreshModels", function() { onControlnetModelChange() }) -// tip for Flux +// utilities for Flux and Chroma let sdModelField = document.querySelector("#stable_diffusion_model") + +// function checkAndSetDependentModels() { +// let sdModel = sdModelField.value.toLowerCase() +// let isFlux = sdModel.includes("flux") +// let isChroma = sdModel.includes("chroma") + +// if (isFlux || isChroma) { +// vaeModelField.value = "ae" + +// if (isFlux) { +// textEncoderModelField.modelNames = ["t5xxl_fp16", "clip_l"] +// } else { +// textEncoderModelField.modelNames = ["t5xxl_fp16"] +// } +// } else { +// if (vaeModelField.value == "ae") { +// vaeModelField.value = "" +// } +// textEncoderModelField.modelNames = [] +// } +// } +// sdModelField.addEventListener("change", checkAndSetDependentModels) + function checkGuidanceValue() { let guidance = parseFloat(guidanceScaleField.value) let guidanceWarning = document.querySelector("#guidanceWarning") @@ -1917,15 +1953,16 @@ sdModelField.addEventListener("change", checkGuidanceValue) guidanceScaleField.addEventListener("change", checkGuidanceValue) guidanceScaleSlider.addEventListener("change", checkGuidanceValue) -function checkGuidanceScaleVisibility() { - let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container") - if (sdModelField.value.toLowerCase().includes("flux")) { - guidanceScaleContainer.classList.remove("displayNone") - } else { - guidanceScaleContainer.classList.add("displayNone") - } -} -sdModelField.addEventListener("change", checkGuidanceScaleVisibility) +// disabling until we can detect flux models more reliably +// function checkGuidanceScaleVisibility() { +// let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container") +// if (sdModelField.value.toLowerCase().includes("flux")) { +// guidanceScaleContainer.classList.remove("displayNone") +// } else { +// guidanceScaleContainer.classList.add("displayNone") +// } +// } +// sdModelField.addEventListener("change", checkGuidanceScaleVisibility) function checkFluxSampler() { let samplerWarning = document.querySelector("#fluxSamplerWarning") @@ -1980,8 +2017,9 @@ schedulerField.addEventListener("change", checkFluxSchedulerSteps) numInferenceStepsField.addEventListener("change", checkFluxSchedulerSteps) document.addEventListener("refreshModels", function() { + // checkAndSetDependentModels() checkGuidanceValue() - checkGuidanceScaleVisibility() + // checkGuidanceScaleVisibility() checkFluxSampler() checkFluxScheduler() checkFluxSchedulerSteps() diff --git a/ui/media/js/multi-model-selector.js b/ui/media/js/multi-model-selector.js index 0640288f..349e3782 100644 --- a/ui/media/js/multi-model-selector.js +++ b/ui/media/js/multi-model-selector.js @@ -10,6 +10,7 @@ class MultiModelSelector { root modelType modelNameFriendly + showWeights defaultWeight weightStep @@ -35,13 +36,13 @@ class MultiModelSelector { if (typeof modelData !== "object") { throw new Error("Multi-model selector expects an object containing modelNames and modelWeights as keys!") } - if (!("modelNames" in modelData) || !("modelWeights" in modelData)) { + if (!("modelNames" in modelData) || (this.showWeights && !("modelWeights" in modelData))) { throw new Error("modelNames or modelWeights not present in the data passed to the multi-model selector") } let newModelNames = modelData["modelNames"] let newModelWeights = modelData["modelWeights"] - if (newModelNames.length !== newModelWeights.length) { + if (newModelWeights && newModelNames.length !== newModelWeights.length) { throw new Error("Need to pass an equal number of modelNames and modelWeights!") } @@ -50,7 +51,7 @@ class MultiModelSelector { // the root of all this unholiness is because searchable-models automatically dispatches an update event // as soon as the value is updated via JS, which is against the DOM pattern of not dispatching an event automatically // unless the caller explicitly dispatches the event. - this.modelWeights = newModelWeights + this.modelWeights = newModelWeights || [] this.modelNames = newModelNames } get disabled() { @@ -91,10 +92,11 @@ class MultiModelSelector { } } - constructor(root, modelType, modelNameFriendly = undefined, defaultWeight = 0.5, weightStep = 0.02) { + constructor(root, modelType, modelNameFriendly = undefined, defaultWeight = 0.5, weightStep = 0.02, showWeights = true) { this.root = root this.modelType = modelType this.modelNameFriendly = modelNameFriendly || modelType + this.showWeights = showWeights this.defaultWeight = defaultWeight this.weightStep = weightStep @@ -135,10 +137,13 @@ class MultiModelSelector { const modelElement = document.createElement("div") modelElement.className = "model_entry" - modelElement.innerHTML = ` - - - ` + let html = `` + + if (this.showWeights) { + html += `` + } + modelElement.innerHTML = html + this.modelContainer.appendChild(modelElement) let modelNameEl = modelElement.querySelector(".model_name") @@ -160,8 +165,8 @@ class MultiModelSelector { modelNameEl.addEventListener("change", makeUpdateEvent("change")) modelNameEl.addEventListener("input", makeUpdateEvent("input")) - modelWeightEl.addEventListener("change", makeUpdateEvent("change")) - modelWeightEl.addEventListener("input", makeUpdateEvent("input")) + modelWeightEl?.addEventListener("change", makeUpdateEvent("change")) + modelWeightEl?.addEventListener("input", makeUpdateEvent("input")) let removeBtn = document.createElement("button") removeBtn.className = "remove_model_btn" @@ -218,10 +223,14 @@ class MultiModelSelector { } get modelWeights() { - return this.getModelElements(true).map((e) => e.weight.value) + return this.getModelElements(true).map((e) => e.weight?.value) } set modelWeights(newModelWeights) { + if (!this.showWeights) { + return + } + this.resizeEntryList(newModelWeights.length) if (newModelWeights.length === 0) {