Support custom text encoders and Flux VAEs in the UI

cmdr2
2025-07-14 13:20:26 +05:30
parent 497b996ce9
commit 889a070e62
10 changed files with 218 additions and 78 deletions

View File

@@ -124,10 +124,10 @@ def update_modules():
     # if sdkit is 2.0.15.x (or lower), then diffusers should be restricted to 0.21.4 (see below for the reason)
     # otherwise use the current sdkit version (with the corresponding diffusers version)
-    expected_sdkit_version_str = "2.0.22.8"
+    expected_sdkit_version_str = "2.0.22.9"
     expected_diffusers_version_str = "0.28.2"
-    legacy_sdkit_version_str = "2.0.15.17"
+    legacy_sdkit_version_str = "2.0.15.18"
     legacy_diffusers_version_str = "0.21.4"
     sdkit_version_str = version("sdkit")
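
For reference, a minimal sketch (not part of this commit) of the version rule described by the comment above, assuming the standard importlib.metadata and packaging modules; the pin values are the ones set in this hunk.

from importlib.metadata import version
from packaging.version import Version

expected_diffusers_version_str = "0.28.2"   # current pin
legacy_sdkit_version_str = "2.0.15.18"      # legacy pin (this commit)
legacy_diffusers_version_str = "0.21.4"

# sdkit 2.0.15.x (or lower) must stay on the legacy diffusers pin;
# newer sdkit versions track the current diffusers pin.
installed = Version(version("sdkit"))
if installed <= Version(legacy_sdkit_version_str):
    required_diffusers = legacy_diffusers_version_str
else:
    required_diffusers = expected_diffusers_version_str
print(f"sdkit {installed} -> diffusers {required_diffusers}")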

View File

@@ -18,6 +18,7 @@ from .impl import (
     ping,
     load_model,
     unload_model,
+    flush_model_changes,
     set_options,
     generate_images,
     filter_images,
@@ -53,6 +54,7 @@ MODELS_TO_OVERRIDE = {
"codeformer": "--codeformer-models-path", "codeformer": "--codeformer-models-path",
"embeddings": "--embeddings-dir", "embeddings": "--embeddings-dir",
"controlnet": "--controlnet-dir", "controlnet": "--controlnet-dir",
"text-encoder": "--text-encoder-dir",
} }
WEBUI_PATCHES = [ WEBUI_PATCHES = [
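
A minimal sketch of how a mapping like MODELS_TO_OVERRIDE is typically consumed: each model type's folder is handed to the WebUI backend through its CLI flag. The helper name and the models_dir layout below are illustrative assumptions, not code from this repository.

import os

MODELS_TO_OVERRIDE = {
    "codeformer": "--codeformer-models-path",
    "embeddings": "--embeddings-dir",
    "controlnet": "--controlnet-dir",
    "text-encoder": "--text-encoder-dir",  # flag added by this commit
}

def build_model_dir_args(models_dir: str) -> list:
    # hypothetical helper: turn the mapping into CLI arguments for the backend process
    args = []
    for model_type, flag in MODELS_TO_OVERRIDE.items():
        model_dir = os.path.join(models_dir, model_type)
        if os.path.isdir(model_dir):
            args += [flag, model_dir]
    return args

# build_model_dir_args("/path/to/models")
# -> ["--controlnet-dir", "/path/to/models/controlnet", "--text-encoder-dir", "/path/to/models/text-encoder", ...]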

View File

@@ -27,6 +27,7 @@ webui_opts: dict = None
 curr_models = {
     "stable-diffusion": None,
     "vae": None,
+    "text-encoder": None,
 }
@@ -96,50 +97,51 @@ def ping(timeout=1):
 def load_model(context, model_type, **kwargs):
+    from easydiffusion.app import ROOT_DIR, getConfig
+
+    config = getConfig()
+    models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))
+
     model_path = context.model_paths[model_type]
+    if model_type == "stable-diffusion":
+        base_dir = os.path.join(models_dir, model_type)
+        model_path = os.path.relpath(model_path, base_dir)
+
+    # print(f"load model: {model_type=} {model_path=} {curr_models=}")
+    curr_models[model_type] = model_path
+
+
+def unload_model(context, model_type, **kwargs):
+    # print(f"unload model: {model_type=} {curr_models=}")
+    curr_models[model_type] = None
+
+
+def flush_model_changes(context):
     if webui_opts is None:
         print("Server not ready, can't set the model")
         return
 
-    if model_type == "stable-diffusion":
-        model_name = os.path.basename(model_path)
-        model_name = os.path.splitext(model_name)[0]
-        print(f"setting sd model: {model_name}")
-        if curr_models[model_type] != model_name:
-            try:
-                res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name})
-                if res.status_code != 200:
-                    raise Exception(res.text)
-            except Exception as e:
-                raise RuntimeError(
-                    f"The engine failed to set the required options. Please check the logs in the command line window for more details."
-                )
-            curr_models[model_type] = model_name
-    elif model_type == "vae":
-        if curr_models[model_type] != model_path:
-            vae_model = [model_path] if model_path else []
-            opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model}
-            print("setting opts 2", opts)
-            try:
-                res = webui_post("/sdapi/v1/options", json=opts)
-                if res.status_code != 200:
-                    raise Exception(res.text)
-            except Exception as e:
-                raise RuntimeError(
-                    f"The engine failed to set the required options. Please check the logs in the command line window for more details."
-                )
-            curr_models[model_type] = model_path
-
-
-def unload_model(context, model_type, **kwargs):
-    if model_type == "vae":
-        context.model_paths[model_type] = None
-        load_model(context, model_type)
+    modules = []
+    for model_type in ("vae", "text-encoder"):
+        if curr_models[model_type]:
+            model_paths = curr_models[model_type]
+            model_paths = [model_paths] if not isinstance(model_paths, list) else model_paths
+            modules += model_paths
+
+    opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": modules}
+
+    print("Setting backend models", opts)
+    try:
+        res = webui_post("/sdapi/v1/options", json=opts)
+        print("got res", res.status_code)
+        if res.status_code != 200:
+            raise Exception(res.text)
+    except Exception as e:
+        raise RuntimeError(
+            f"The engine failed to set the required options. Please check the logs in the command line window for more details."
+        )
 
 
 def generate_images(
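
To illustrate the new flow: load_model()/unload_model() above only record the requested models, and flush_model_changes() pushes them to the backend in a single /sdapi/v1/options call, with the VAE and text encoders combined into "forge_additional_modules". A minimal sketch of that payload, using requests as a stand-in for the webui_post() helper; the host/port and model names are assumptions.

import requests

curr_models = {
    "stable-diffusion": "flux1-dev",            # illustrative model names
    "vae": "ae",
    "text-encoder": ["t5xxl_fp16", "clip_l"],
}

modules = []
for model_type in ("vae", "text-encoder"):
    if curr_models[model_type]:
        model_paths = curr_models[model_type]
        modules += model_paths if isinstance(model_paths, list) else [model_paths]

opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": modules}

# assumes the WebUI backend is listening locally on port 7860
res = requests.post("http://localhost:7860/sdapi/v1/options", json=opts, timeout=60)
res.raise_for_status()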
@@ -346,7 +348,7 @@ def refresh_models():
             pass
 
     try:
-        for type in ("checkpoints", "vae"):
+        for type in ("checkpoints", "vae-and-text-encoders"):
             t = Thread(target=make_refresh_call, args=(type,))
             t.start()
     except Exception as e:

View File

@@ -3,6 +3,7 @@ import shutil
 from glob import glob
 import traceback
 from typing import Union
+from os import path
 
 from easydiffusion import app
 from easydiffusion.types import ModelsData
@@ -22,6 +23,7 @@ KNOWN_MODEL_TYPES = [
"codeformer", "codeformer",
"embeddings", "embeddings",
"controlnet", "controlnet",
"text-encoder",
] ]
MODEL_EXTENSIONS = { MODEL_EXTENSIONS = {
"stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"], "stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"],
@@ -33,6 +35,7 @@ MODEL_EXTENSIONS = {
"codeformer": [".pth"], "codeformer": [".pth"],
"embeddings": [".pt", ".bin", ".safetensors", ".sft"], "embeddings": [".pt", ".bin", ".safetensors", ".sft"],
"controlnet": [".pth", ".safetensors", ".sft"], "controlnet": [".pth", ".safetensors", ".sft"],
"text-encoder": [".safetensors", ".sft"],
} }
DEFAULT_MODELS = { DEFAULT_MODELS = {
"stable-diffusion": [ "stable-diffusion": [
@@ -59,6 +62,7 @@ ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility
"realesrgan": "RealESRGAN", "realesrgan": "RealESRGAN",
"lora": "Lora", "lora": "Lora",
"controlnet": "ControlNet", "controlnet": "ControlNet",
"text-encoder": "text_encoder",
} }
known_models = {} known_models = {}
@@ -204,6 +208,9 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models
                     context.model_load_errors = {}
                 context.model_load_errors[model_type] = str(e)  # storing the entire Exception can lead to memory leaks
 
+    if hasattr(backend, "flush_model_changes"):
+        backend.flush_model_changes(context)
+
 
 def resolve_model_paths(models_data: ModelsData):
     from easydiffusion.backend_manager import backend
@@ -224,14 +231,25 @@ def resolve_model_paths(models_data: ModelsData):
     for model_type in model_paths:
         if model_type in skip_models:  # doesn't use model paths
             continue
-        if model_type == "codeformer" and model_paths[model_type]:
-            download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
-        elif model_type == "controlnet" and model_paths[model_type]:
-            model_id = model_paths[model_type]
-            model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
-            if model_info:
-                filename = model_info.get("url", "").split("/")[-1]
-                download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False)
+
+        if model_type in ("vae", "codeformer", "controlnet", "text-encoder") and model_paths[model_type]:
+            model_ids = model_paths[model_type]
+            model_ids = model_ids if isinstance(model_ids, list) else [model_ids]
+
+            new_model_paths = []
+
+            for model_id in model_ids:
+                log.info(f"Checking for {model_id=}")
+                model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
+                if model_info:
+                    filename = model_info.get("url", "").split("/")[-1]
+                    download_if_necessary(model_type, filename, model_id, skip_if_others_exist=False)
+                    new_model_paths.append(path.splitext(filename)[0])
+                else:  # not in the model db, probably a regular file
+                    new_model_paths.append(model_id)
+
+            model_paths[model_type] = new_model_paths
 
         model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
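
A self-contained sketch of the resolution rule introduced above, with a mock model DB standing in for get_model_info_from_db(): a single ID or a list of IDs is accepted, known IDs resolve to the downloaded file's stem, and unknown IDs are passed through as regular files.

from os import path

MOCK_MODEL_DB = {  # stand-in for the real model DB
    "t5xxl_fp16": {"url": "https://example.com/t5xxl_fp16.safetensors"},
}

def resolve_ids(model_ids):
    model_ids = model_ids if isinstance(model_ids, list) else [model_ids]
    resolved = []
    for model_id in model_ids:
        info = MOCK_MODEL_DB.get(model_id)
        if info:
            filename = info.get("url", "").split("/")[-1]
            resolved.append(path.splitext(filename)[0])  # the real code downloads it first
        else:
            resolved.append(model_id)  # not in the DB, probably a regular file on disk
    return resolved

print(resolve_ids("t5xxl_fp16"))                  # ['t5xxl_fp16']
print(resolve_ids(["t5xxl_fp16", "my_clip_l"]))   # ['t5xxl_fp16', 'my_clip_l']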
@@ -256,17 +274,31 @@ def download_default_models_if_necessary():
 def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
-    model_dir = get_model_dirs(model_type)[0]
-    model_path = os.path.join(model_dir, file_name)
+    from easydiffusion.backend_manager import backend
 
     expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]
 
     other_models_exist = any_model_exists(model_type) and skip_if_others_exist
-    known_model_exists = os.path.exists(model_path)
-    known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
 
-    if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
-        print("> download", model_type, model_id)
-        download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False)
+    for model_dir in get_model_dirs(model_type):
+        model_path = os.path.join(model_dir, file_name)
+
+        known_model_exists = os.path.exists(model_path)
+        known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash
+
+        needs_download = known_model_is_corrupt or (not other_models_exist and not known_model_exists)
+        log.info(f"{model_path=} {needs_download=}")
+        if known_model_exists:
+            log.info(f"{expected_hash=} {hash_file_quick(model_path)=}")
+        log.info(f"{known_model_is_corrupt=} {other_models_exist=} {known_model_exists=}")
+
+        if not needs_download:
+            return
+
+        print("> download", model_type, model_id)
+        download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False)
+
+        backend.refresh_models()
 
 
 def migrate_legacy_model_location():
@@ -363,7 +395,7 @@ def getModels(scan_for_malicious: bool = True):
     models = {
         "options": {
             "stable-diffusion": [],
-            "vae": [],
+            "vae": [{"ae": "ae (Flux VAE fp16)"}],
             "hypernetwork": [],
             "lora": [],
             "codeformer": [{"codeformer": "CodeFormer"}],
@@ -383,6 +415,11 @@ def getModels(scan_for_malicious: bool = True):
# {"control_v11e_sd15_shuffle": "Shuffle"}, # {"control_v11e_sd15_shuffle": "Shuffle"},
# {"control_v11f1e_sd15_tile": "Tile"}, # {"control_v11f1e_sd15_tile": "Tile"},
], ],
"text-encoder": [
{"t5xxl_fp16": "T5 XXL fp16"},
{"clip_l": "CLIP L"},
{"clip_g": "CLIP G"},
],
}, },
} }
@@ -466,6 +503,7 @@ def getModels(scan_for_malicious: bool = True):
     listModels(model_type="lora")
     listModels(model_type="embeddings", nameFilter=get_embedding_token)
     listModels(model_type="controlnet")
+    listModels(model_type="text-encoder")
 
     if scan_for_malicious and models_scanned > 0:
         log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")

View File

@@ -80,6 +80,7 @@ class RenderTaskData(TaskData):
     latent_upscaler_steps: int = 10
     use_stable_diffusion_model: Union[str, List[str]] = "sd-v1-4"
     use_vae_model: Union[str, List[str]] = None
+    use_text_encoder_model: Union[str, List[str]] = None
     use_hypernetwork_model: Union[str, List[str]] = None
     use_lora_model: Union[str, List[str]] = None
     use_controlnet_model: Union[str, List[str]] = None
@@ -211,6 +212,7 @@ def convert_legacy_render_req_to_new(old_req: dict):
     # move the model info
     model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
     model_paths["vae"] = old_req.get("use_vae_model")
+    model_paths["text-encoder"] = old_req.get("use_text_encoder_model")
     model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
     model_paths["lora"] = old_req.get("use_lora_model")
     model_paths["controlnet"] = old_req.get("use_controlnet_model")

View File

@@ -302,6 +302,14 @@
<input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" /> <input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<a href="https://github.com/easydiffusion/easydiffusion/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
</td></tr> </td></tr>
<tr id="text_encoder_model_container" class="pl-5 gated-feature" data-feature-keys="backend_webui">
<td>
<label for="text_encoder_model">Text Encoder:</label>
</td>
<td>
<div id="text_encoder_model" data-path=""></div>
</td>
</tr>
<tr id="samplerSelection" class="pl-5"><td><label for="sampler_name">Sampler:</label></td><td> <tr id="samplerSelection" class="pl-5"><td><label for="sampler_name">Sampler:</label></td><td>
<select id="sampler_name" name="sampler_name"> <select id="sampler_name" name="sampler_name">
<option value="plms">PLMS</option> <option value="plms">PLMS</option>
@@ -442,7 +450,7 @@
<tr class="pl-5"><td><label for="guidance_scale_slider">Guidance Scale:</label></td><td> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="11" max="500"> <input id="guidance_scale" name="guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr> <tr class="pl-5"><td><label for="guidance_scale_slider">Guidance Scale:</label></td><td> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="11" max="500"> <input id="guidance_scale" name="guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr class="pl-5 displayNone warning-label" id="guidanceWarning"><td></td><td id="guidanceWarningText"></td></tr> <tr class="pl-5 displayNone warning-label" id="guidanceWarning"><td></td><td id="guidanceWarningText"></td></tr>
<tr id="prompt_strength_container" class="pl-5"><td><label for="prompt_strength_slider">Prompt Strength:</label></td><td> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td></tr> <tr id="prompt_strength_container" class="pl-5"><td><label for="prompt_strength_slider">Prompt Strength:</label></td><td> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td></tr>
<tr id="distilled_guidance_scale_container" class="pl-5 displayNone"><td><label for="distilled_guidance_scale_slider">Distilled Guidance:</label></td><td> <input id="distilled_guidance_scale_slider" name="distilled_guidance_scale_slider" class="editor-slider" value="35" type="range" min="11" max="500"> <input id="distilled_guidance_scale" name="distilled_guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr> <tr id="distilled_guidance_scale_container" class="pl-5 gated-feature" data-feature-keys="backend_webui"><td><label for="distilled_guidance_scale_slider">Distilled Guidance:</label></td><td> <input id="distilled_guidance_scale_slider" name="distilled_guidance_scale_slider" class="editor-slider" value="35" type="range" min="11" max="500"> <input id="distilled_guidance_scale" name="distilled_guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr id="lora_model_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers backend_webui"> <tr id="lora_model_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">
<td> <td>
<label for="lora_model">LoRA:</label> <label for="lora_model">LoRA:</label>

View File

@@ -60,6 +60,7 @@ const SETTINGS_IDS_LIST = [
"extract_lora_from_prompt", "extract_lora_from_prompt",
"embedding-card-size-selector", "embedding-card-size-selector",
"lora_model", "lora_model",
"text_encoder_model",
"enable_vae_tiling", "enable_vae_tiling",
"controlnet_alpha", "controlnet_alpha",
] ]

View File

@@ -394,6 +394,45 @@ const TASK_MAPPING = {
             return val
         },
     },
+    use_text_encoder_model: {
+        name: "Text Encoder model",
+        setUI: (use_text_encoder_model) => {
+            let modelPaths = []
+            use_text_encoder_model = use_text_encoder_model === null ? "" : use_text_encoder_model
+            use_text_encoder_model = Array.isArray(use_text_encoder_model) ? use_text_encoder_model : [use_text_encoder_model]
+            use_text_encoder_model.forEach((m) => {
+                if (m.includes("models\\text-encoder\\")) {
+                    m = m.split("models\\text-encoder\\")[1]
+                } else if (m.includes("models\\\\text-encoder\\\\")) {
+                    m = m.split("models\\\\text-encoder\\\\")[1]
+                } else if (m.includes("models/text-encoder/")) {
+                    m = m.split("models/text-encoder/")[1]
+                }
+                m = m.replaceAll("\\\\", "/")
+                m = getModelPath(m, [".safetensors", ".sft"])
+                modelPaths.push(m)
+            })
+            text_encoderModelField.modelNames = modelPaths
+        },
+        readUI: () => {
+            return text_encoderModelField.modelNames
+        },
+        parse: (val) => {
+            val = !val || val === "None" ? "" : val
+            if (typeof val === "string" && val.includes(",")) {
+                val = val.split(",")
+                val = val.map((v) => v.trim())
+                val = val.map((v) => v.replaceAll("\\", "\\\\"))
+                val = val.map((v) => v.replaceAll('"', ""))
+                val = val.map((v) => v.replaceAll("'", ""))
+                val = val.map((v) => '"' + v + '"')
+                val = "[" + val + "]"
+                val = JSON.parse(val)
+            }
+            val = Array.isArray(val) ? val : [val]
+            return val
+        },
+    },
     use_hypernetwork_model: {
         name: "Hypernetwork model",
         setUI: (use_hypernetwork_model) => {
@@ -620,6 +659,7 @@ const TASK_TEXT_MAPPING = {
     hypernetwork_strength: "Hypernetwork Strength",
     use_lora_model: "LoRA model",
     lora_alpha: "LoRA Strength",
+    use_text_encoder_model: "Text Encoder model",
     use_controlnet_model: "ControlNet model",
     control_filter_to_apply: "ControlNet Filter",
     control_alpha: "ControlNet Strength",

View File

@@ -54,6 +54,7 @@ const taskConfigSetup = {
label: "Hypernetwork Strength", label: "Hypernetwork Strength",
visible: ({ reqBody }) => !!reqBody?.use_hypernetwork_model, visible: ({ reqBody }) => !!reqBody?.use_hypernetwork_model,
}, },
use_text_encoder_model: { label: "Text Encoder", visible: ({ reqBody }) => !!reqBody?.use_text_encoder_model },
use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model }, use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model }, lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
preserve_init_image_color_profile: "Preserve Color Profile", preserve_init_image_color_profile: "Preserve Color Profile",
@@ -141,6 +142,7 @@ let tilingField = document.querySelector("#tiling")
 let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None", false)
 let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None")
 let loraModelField = new MultiModelSelector(document.querySelector("#lora_model"), "lora", "LoRA", 0.5, 0.02)
+let textEncoderModelField = new MultiModelSelector(document.querySelector("#text_encoder_model"), "text-encoder", "Text Encoder", 0.5, 0.02, false)
 let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None")
 let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider")
 let hypernetworkStrengthField = document.querySelector("#hypernetwork_strength")
@@ -1396,6 +1398,7 @@ function getCurrentUserRequest() {
         newTask.reqBody.hypernetwork_strength = parseFloat(hypernetworkStrengthField.value)
     }
     if (testDiffusers.checked) {
+        // lora
         let loraModelData = loraModelField.value
         let modelNames = loraModelData["modelNames"]
         let modelStrengths = loraModelData["modelWeights"]
@@ -1408,6 +1411,16 @@ function getCurrentUserRequest() {
             newTask.reqBody.lora_alpha = modelStrengths
         }
 
+        // text encoder
+        let textEncoderModelNames = textEncoderModelField.modelNames
+        if (textEncoderModelNames.length > 0) {
+            textEncoderModelNames = textEncoderModelNames.length == 1 ? textEncoderModelNames[0] : textEncoderModelNames
+            newTask.reqBody.use_text_encoder_model = textEncoderModelNames
+        }
+
+        // vae tiling
         if (tilingField.value !== "none") {
             newTask.reqBody.tiling = tilingField.value
         }
@@ -1891,8 +1904,31 @@ document.addEventListener("refreshModels", function() {
     onControlnetModelChange()
 })
 
-// tip for Flux
+// utilities for Flux and Chroma
 let sdModelField = document.querySelector("#stable_diffusion_model")
+// function checkAndSetDependentModels() {
+//     let sdModel = sdModelField.value.toLowerCase()
+//     let isFlux = sdModel.includes("flux")
+//     let isChroma = sdModel.includes("chroma")
+//     if (isFlux || isChroma) {
+//         vaeModelField.value = "ae"
+//         if (isFlux) {
+//             textEncoderModelField.modelNames = ["t5xxl_fp16", "clip_l"]
+//         } else {
+//             textEncoderModelField.modelNames = ["t5xxl_fp16"]
+//         }
+//     } else {
+//         if (vaeModelField.value == "ae") {
+//             vaeModelField.value = ""
+//         }
+//         textEncoderModelField.modelNames = []
+//     }
+// }
+// sdModelField.addEventListener("change", checkAndSetDependentModels)
 
 function checkGuidanceValue() {
     let guidance = parseFloat(guidanceScaleField.value)
     let guidanceWarning = document.querySelector("#guidanceWarning")
@@ -1917,15 +1953,16 @@ sdModelField.addEventListener("change", checkGuidanceValue)
 guidanceScaleField.addEventListener("change", checkGuidanceValue)
 guidanceScaleSlider.addEventListener("change", checkGuidanceValue)
 
-function checkGuidanceScaleVisibility() {
-    let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container")
-    if (sdModelField.value.toLowerCase().includes("flux")) {
-        guidanceScaleContainer.classList.remove("displayNone")
-    } else {
-        guidanceScaleContainer.classList.add("displayNone")
-    }
-}
-sdModelField.addEventListener("change", checkGuidanceScaleVisibility)
+// disabling until we can detect flux models more reliably
+// function checkGuidanceScaleVisibility() {
+//     let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container")
+//     if (sdModelField.value.toLowerCase().includes("flux")) {
+//         guidanceScaleContainer.classList.remove("displayNone")
+//     } else {
+//         guidanceScaleContainer.classList.add("displayNone")
+//     }
+// }
+// sdModelField.addEventListener("change", checkGuidanceScaleVisibility)
 
 function checkFluxSampler() {
     let samplerWarning = document.querySelector("#fluxSamplerWarning")
@@ -1980,8 +2017,9 @@ schedulerField.addEventListener("change", checkFluxSchedulerSteps)
 numInferenceStepsField.addEventListener("change", checkFluxSchedulerSteps)
 
 document.addEventListener("refreshModels", function() {
+    // checkAndSetDependentModels()
     checkGuidanceValue()
-    checkGuidanceScaleVisibility()
+    // checkGuidanceScaleVisibility()
     checkFluxSampler()
     checkFluxScheduler()
     checkFluxSchedulerSteps()

View File

@@ -10,6 +10,7 @@ class MultiModelSelector {
     root
     modelType
     modelNameFriendly
+    showWeights
     defaultWeight
     weightStep
@@ -35,13 +36,13 @@ class MultiModelSelector {
         if (typeof modelData !== "object") {
             throw new Error("Multi-model selector expects an object containing modelNames and modelWeights as keys!")
         }
-        if (!("modelNames" in modelData) || !("modelWeights" in modelData)) {
+        if (!("modelNames" in modelData) || (this.showWeights && !("modelWeights" in modelData))) {
             throw new Error("modelNames or modelWeights not present in the data passed to the multi-model selector")
         }
 
         let newModelNames = modelData["modelNames"]
         let newModelWeights = modelData["modelWeights"]
-        if (newModelNames.length !== newModelWeights.length) {
+        if (newModelWeights && newModelNames.length !== newModelWeights.length) {
             throw new Error("Need to pass an equal number of modelNames and modelWeights!")
         }
@@ -50,7 +51,7 @@ class MultiModelSelector {
         // the root of all this unholiness is because searchable-models automatically dispatches an update event
         // as soon as the value is updated via JS, which is against the DOM pattern of not dispatching an event automatically
         // unless the caller explicitly dispatches the event.
-        this.modelWeights = newModelWeights
+        this.modelWeights = newModelWeights || []
         this.modelNames = newModelNames
     }
     get disabled() {
@@ -91,10 +92,11 @@ class MultiModelSelector {
         }
     }
 
-    constructor(root, modelType, modelNameFriendly = undefined, defaultWeight = 0.5, weightStep = 0.02) {
+    constructor(root, modelType, modelNameFriendly = undefined, defaultWeight = 0.5, weightStep = 0.02, showWeights = true) {
         this.root = root
         this.modelType = modelType
         this.modelNameFriendly = modelNameFriendly || modelType
+        this.showWeights = showWeights
         this.defaultWeight = defaultWeight
         this.weightStep = weightStep
@@ -135,10 +137,13 @@ class MultiModelSelector {
         const modelElement = document.createElement("div")
         modelElement.className = "model_entry"
-        modelElement.innerHTML = `
-            <input id="${this.modelType}_${idx}" class="model_name model-filter" type="text" spellcheck="false" autocomplete="off" data-path="" />
-            <input class="model_weight" type="number" step="${this.weightStep}" value="${this.defaultWeight}" pattern="^-?[0-9]*\.?[0-9]*$" onkeypress="preventNonNumericalInput(event)">
-        `
+        let html = `<input id="${this.modelType}_${idx}" class="model_name model-filter" type="text" spellcheck="false" autocomplete="off" data-path="" />`
+
+        if (this.showWeights) {
+            html += `<input class="model_weight" type="number" step="${this.weightStep}" value="${this.defaultWeight}" pattern="^-?[0-9]*\.?[0-9]*$" onkeypress="preventNonNumericalInput(event)">`
+        }
+
+        modelElement.innerHTML = html
         this.modelContainer.appendChild(modelElement)
 
         let modelNameEl = modelElement.querySelector(".model_name")
@@ -160,8 +165,8 @@ class MultiModelSelector {
         modelNameEl.addEventListener("change", makeUpdateEvent("change"))
         modelNameEl.addEventListener("input", makeUpdateEvent("input"))
-        modelWeightEl.addEventListener("change", makeUpdateEvent("change"))
-        modelWeightEl.addEventListener("input", makeUpdateEvent("input"))
+        modelWeightEl?.addEventListener("change", makeUpdateEvent("change"))
+        modelWeightEl?.addEventListener("input", makeUpdateEvent("input"))
 
         let removeBtn = document.createElement("button")
         removeBtn.className = "remove_model_btn"
@@ -218,10 +223,14 @@ class MultiModelSelector {
     }
 
     get modelWeights() {
-        return this.getModelElements(true).map((e) => e.weight.value)
+        return this.getModelElements(true).map((e) => e.weight?.value)
     }
     set modelWeights(newModelWeights) {
+        if (!this.showWeights) {
+            return
+        }
+
         this.resizeEntryList(newModelWeights.length)
 
         if (newModelWeights.length === 0) {